Skip to content

Commit a66ce09

Browse files
tjeerddieMark90
andauthored
Improve subscription caching with domain model endpoint and graphql subscriptions (#663)
* Improve subscription caching with get domain model and graphql subscriptions - change get_subscription_dict to a util function and add subscription to cache when not fetched from cache. - use get_subscription_dict for get domain model endpoint and in graphql get_subscription_details * Bump version to 2.3.0rc4 * Remove to_redis in get_subscription_dict and use _generate_etag * Add docstring to get_subscription_dict and add unit tests --------- Co-authored-by: Mark90 <[email protected]>
1 parent 3bf1278 commit a66ce09

File tree

10 files changed

+95
-51
lines changed

10 files changed

+95
-51
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 2.3.0rc3
2+
current_version = 2.3.0rc4
33
commit = False
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(rc(?P<build>\d+))?

orchestrator/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
"""This is the orchestrator workflow engine."""
1515

16-
__version__ = "2.3.0rc3"
16+
__version__ = "2.3.0rc4"
1717

1818
from orchestrator.app import OrchestratorCore
1919
from orchestrator.settings import app_settings

orchestrator/api/api_v1/endpoints/subscriptions.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,10 @@
3737
SubscriptionTable,
3838
db,
3939
)
40-
from orchestrator.domain import SubscriptionModel
4140
from orchestrator.schemas import SubscriptionWorkflowListsSchema
4241
from orchestrator.schemas.subscription import SubscriptionDomainModelSchema, SubscriptionWithMetadata
4342
from orchestrator.security import authenticate
4443
from orchestrator.services.subscriptions import (
45-
_generate_etag,
46-
build_extended_domain_model,
4744
format_extended_domain_model,
4845
format_special_types,
4946
get_subscription,
@@ -52,7 +49,7 @@
5249
from orchestrator.settings import app_settings
5350
from orchestrator.types import SubscriptionLifecycle
5451
from orchestrator.utils.deprecation_logger import deprecated_endpoint
55-
from orchestrator.utils.redis import from_redis
52+
from orchestrator.utils.get_subscription_dict import get_subscription_dict
5653

5754
router = APIRouter()
5855

@@ -106,7 +103,7 @@ def _filter_statuses(filter_statuses: str | None = None) -> list[str]:
106103
"/domain-model/{subscription_id}",
107104
response_model=SubscriptionDomainModelSchema | None,
108105
)
109-
def subscription_details_by_id_with_domain_model(
106+
async def subscription_details_by_id_with_domain_model(
110107
request: Request, subscription_id: UUID, response: Response, filter_owner_relations: bool = True
111108
) -> dict[str, Any] | None:
112109
def _build_response(model: dict, etag: str) -> dict[str, Any] | None:
@@ -117,14 +114,9 @@ def _build_response(model: dict, etag: str) -> dict[str, Any] | None:
117114
filtered = format_extended_domain_model(model, filter_owner_relations=filter_owner_relations)
118115
return format_special_types(filtered)
119116

120-
if cache_response := from_redis(subscription_id):
121-
return _build_response(*cache_response)
122-
123117
try:
124-
subscription_model = SubscriptionModel.from_subscription(subscription_id)
125-
extended_model = build_extended_domain_model(subscription_model)
126-
etag = _generate_etag(extended_model)
127-
return _build_response(extended_model, etag)
118+
subscription, etag = await get_subscription_dict(subscription_id)
119+
return _build_response(subscription, etag)
128120
except ValueError as e:
129121
if str(e) == f"Subscription with id: {subscription_id}, does not exist":
130122
raise_status(HTTPStatus.NOT_FOUND, f"Subscription with id: {subscription_id}, not found")

orchestrator/graphql/resolvers/subscription.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sqlalchemy import Select, func, select
2121
from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic
2222

23+
from nwastdlib.asyncio import gather_nice
2324
from orchestrator.db import ProductTable, SubscriptionTable, db
2425
from orchestrator.db.filters import Filter
2526
from orchestrator.db.filters.subscription import (
@@ -33,7 +34,6 @@
3334
sort_subscriptions,
3435
subscription_sort_fields,
3536
)
36-
from orchestrator.domain.base import SubscriptionModel
3737
from orchestrator.graphql.pagination import Connection
3838
from orchestrator.graphql.schemas.product import ProductModelGraphql
3939
from orchestrator.graphql.schemas.subscription import SubscriptionInterface
@@ -48,7 +48,7 @@
4848
is_querying_page_data,
4949
to_graphql_result_page,
5050
)
51-
from orchestrator.types import SubscriptionLifecycle
51+
from orchestrator.utils.get_subscription_dict import get_subscription_dict
5252

5353
logger = structlog.get_logger(__name__)
5454
# Note: we can make this more fancy by adding metadata to the field annotation that indicates if a resolver
@@ -65,26 +65,34 @@
6565
def get_subscription_graphql_type(info: OrchestratorInfo, subscription_name: str) -> StrawberryTypeFromPydantic:
6666
subscription_graphql_type = info.context.graphql_models.get(subscription_name)
6767
if not subscription_graphql_type:
68-
raise GraphQLError(message=f"No graphql type found for {subscription_name}")
68+
logger.warning(message=f"No graphql type found for {subscription_name}")
69+
base_type = info.context.graphql_models.get("subscription")
70+
if not base_type:
71+
raise GraphQLError("No subscription base type found")
72+
return base_type
6973
return subscription_graphql_type
7074

7175

72-
def get_subscription_details(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
76+
async def get_subscription_details(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
77+
from orchestrator.domain import SUBSCRIPTION_MODEL_REGISTRY
7378
from orchestrator.graphql.autoregistration import graphql_subscription_name
7479

75-
subscription_details = SubscriptionModel.from_subscription(subscription.subscription_id)
76-
base_model = subscription_details.__base_type__ if subscription_details.__base_type__ else subscription_details
77-
base_subscription_details = base_model.from_other_lifecycle( # type: ignore
78-
subscription_details, SubscriptionLifecycle.INITIAL, skip_validation=True
79-
)
80-
base_subscription_details.status = subscription_details.status
81-
strawberry_type = get_subscription_graphql_type(info, graphql_subscription_name(base_model.__name__)) # type: ignore
82-
return strawberry_type.from_pydantic(base_subscription_details) # type:ignore
80+
subscription_dict_data, _ = await get_subscription_dict(subscription.subscription_id)
81+
82+
domain_model_type = SUBSCRIPTION_MODEL_REGISTRY[subscription.product.name]
83+
base_model = domain_model_type.__base_type__ or domain_model_type
84+
85+
subscription_name = graphql_subscription_name(base_model.__name__)
86+
subscription_details = base_model.model_validate(subscription_dict_data, strict=False)
87+
subscription_details._db_model = subscription
88+
89+
strawberry_type = get_subscription_graphql_type(info, subscription_name)
90+
return strawberry_type.from_pydantic(subscription_details) # type: ignore
8391

8492

85-
def format_subscription(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
93+
async def format_subscription(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
8694
if _is_subscription_detailed(info):
87-
return get_subscription_details(info, subscription)
95+
return await get_subscription_details(info, subscription)
8896

8997
strawberry_type = get_subscription_graphql_type(info, "subscription")
9098
return strawberry_type.from_pydantic(subscription) # type:ignore
@@ -94,7 +102,7 @@ async def resolve_subscription(info: OrchestratorInfo, id: UUID) -> Subscription
94102
stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == id)
95103

96104
if subscription := db.session.scalar(stmt):
97-
return format_subscription(info, subscription)
105+
return await format_subscription(info, subscription)
98106
return None
99107

100108

@@ -127,10 +135,10 @@ async def resolve_subscriptions(
127135
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
128136
stmt = apply_range_to_statement(stmt, after, after + first + 1)
129137

130-
graphql_subscriptions = []
138+
graphql_subscriptions: list[SubscriptionInterface] = []
131139
if is_querying_page_data(info):
132140
subscriptions = db.session.scalars(stmt).all()
133-
graphql_subscriptions = [format_subscription(info, p) for p in subscriptions]
141+
graphql_subscriptions = list(await gather_nice((format_subscription(info, p) for p in subscriptions)))
134142
logger.info("Resolve subscriptions", filter_by=filter_by, total=graphql_subscriptions)
135143

136144
return to_graphql_result_page(

orchestrator/graphql/schemas/product_block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def owner_subscription_resolver(
4242
stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == root.owner_subscription_id)
4343

4444
if subscription := db.session.scalar(stmt):
45-
return format_subscription(info, subscription)
45+
return await format_subscription(info, subscription)
4646
return None
4747

4848

orchestrator/graphql/utils/get_subscription_product_blocks.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
from pydantic.alias_generators import to_camel as to_lower_camel
88
from strawberry.scalars import JSON
99

10-
from orchestrator.domain.base import SubscriptionModel
1110
from orchestrator.graphql.schemas.product_block import owner_subscription_resolver
12-
from orchestrator.services.subscriptions import build_extended_domain_model
13-
from orchestrator.utils.redis import from_redis
11+
from orchestrator.utils.get_subscription_dict import get_subscription_dict
1412

1513
if TYPE_CHECKING:
1614
from orchestrator.graphql.schemas.subscription import SubscriptionInterface
@@ -62,19 +60,10 @@ def new_product_block(item: dict[str, Any]) -> Generator:
6260
pb_instance_property_keys = ("id", "parent", "owner_subscription_id", "subscription_instance_id", "in_use_by_relations")
6361

6462

65-
async def get_subscription_dict(subscription_id: UUID) -> dict:
66-
if cached_model := from_redis(subscription_id):
67-
subscription, _ = cached_model
68-
else:
69-
subscription_model = SubscriptionModel.from_subscription(subscription_id)
70-
subscription = build_extended_domain_model(subscription_model)
71-
return subscription
72-
73-
7463
async def get_subscription_product_blocks(
7564
subscription_id: UUID, tags: list[str] | None = None, product_block_instance_values: list[str] | None = None
7665
) -> list[ProductBlockInstance]:
77-
subscription = await get_subscription_dict(subscription_id)
66+
subscription, _ = await get_subscription_dict(subscription_id)
7867

7968
def to_product_block(product_block: dict[str, Any]) -> ProductBlockInstance:
8069
def is_resource_type(candidate: Any) -> bool:
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from uuid import UUID
2+
3+
from orchestrator.domain.base import SubscriptionModel
4+
from orchestrator.services.subscriptions import _generate_etag, build_extended_domain_model
5+
from orchestrator.utils.redis import from_redis
6+
7+
8+
async def get_subscription_dict(subscription_id: UUID) -> tuple[dict, str]:
9+
"""Helper function to get subscription dict by uuid from db or cache."""
10+
11+
if cached_model := from_redis(subscription_id):
12+
return cached_model # type: ignore
13+
14+
subscription_model = SubscriptionModel.from_subscription(subscription_id)
15+
subscription = build_extended_domain_model(subscription_model)
16+
etag = _generate_etag(subscription)
17+
return subscription, etag

orchestrator/utils/redis.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from orchestrator.services.subscriptions import _generate_etag
2626
from orchestrator.settings import app_settings
27-
from orchestrator.utils.json import json_dumps, json_loads
27+
from orchestrator.utils.json import PY_JSON_TYPES, json_dumps, json_loads
2828

2929
logger = get_logger(__name__)
3030

@@ -37,17 +37,19 @@ def caching_models_enabled() -> bool:
3737
return getenv("AIOCACHE_DISABLE", "0") == "0" and app_settings.CACHE_DOMAIN_MODELS
3838

3939

40-
def to_redis(subscription: dict[str, Any]) -> None:
40+
def to_redis(subscription: dict[str, Any]) -> str | None:
4141
if caching_models_enabled():
4242
logger.info("Setting cache for subscription", subscription=subscription["subscription_id"])
4343
etag = _generate_etag(subscription)
4444
cache.set(f"domain:{subscription['subscription_id']}", json_dumps(subscription), ex=ONE_WEEK)
4545
cache.set(f"domain:etag:{subscription['subscription_id']}", etag, ex=ONE_WEEK)
46-
else:
47-
logger.warning("Caching disabled, not caching subscription", subscription=subscription["subscription_id"])
46+
return etag
47+
48+
logger.warning("Caching disabled, not caching subscription", subscription=subscription["subscription_id"])
49+
return None
4850

4951

50-
def from_redis(subscription_id: UUID) -> tuple[Any, str] | None:
52+
def from_redis(subscription_id: UUID) -> tuple[PY_JSON_TYPES, str] | None:
5153
log = logger.bind(subscription_id=subscription_id)
5254
if caching_models_enabled():
5355
log.debug("Try to retrieve subscription from cache")

test/unit_tests/api/test_subscriptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ def test_subscription_detail_with_in_use_by_ids_filtered_self(test_client, produ
737737
assert not response.json()["block"]["sub_block"]["in_use_by_ids"]
738738

739739

740-
@mock.patch("orchestrator.api.api_v1.endpoints.subscriptions.from_redis")
740+
@mock.patch("orchestrator.api.api_v1.endpoints.subscriptions.get_subscription_dict")
741741
def test_subscription_detail_special_fields(mock_from_redis, test_client):
742742
"""Test that a subscription with special field types is correctly serialized by Pydantic.
743743
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from os import getenv
2+
from unittest import mock
3+
from unittest.mock import Mock
4+
5+
import pytest
6+
7+
from orchestrator import app_settings
8+
from orchestrator.domain.base import SubscriptionModel
9+
from orchestrator.services.subscriptions import build_extended_domain_model
10+
from orchestrator.utils.get_subscription_dict import get_subscription_dict
11+
from orchestrator.utils.redis import to_redis
12+
13+
14+
@mock.patch.object(app_settings, "CACHE_DOMAIN_MODELS", False)
15+
@mock.patch("orchestrator.utils.get_subscription_dict._generate_etag")
16+
async def test_get_subscription_dict_db(generate_etag, generic_subscription_1):
17+
generate_etag.side_effect = Mock(return_value="etag-mock")
18+
await get_subscription_dict(generic_subscription_1)
19+
assert generate_etag.called
20+
21+
22+
@pytest.mark.skipif(
23+
not getenv("AIOCACHE_DISABLE", "0") == "0", reason="AIOCACHE must be enabled for this test to do anything"
24+
)
25+
@mock.patch("orchestrator.utils.get_subscription_dict._generate_etag")
26+
async def test_get_subscription_dict_cache(generate_etag, generic_subscription_1, cache_fixture):
27+
subscription = SubscriptionModel.from_subscription(generic_subscription_1)
28+
extended_model = build_extended_domain_model(subscription)
29+
30+
# Add domainmodel to cache
31+
to_redis(extended_model)
32+
cache_fixture.extend([f"domain:{generic_subscription_1}", f"domain:etag:{generic_subscription_1}"])
33+
34+
generate_etag.side_effect = Mock(return_value="etag-mock")
35+
await get_subscription_dict(generic_subscription_1)
36+
assert not generate_etag.called

0 commit comments

Comments
 (0)