Skip to content

Commit 62d64fd

Browse files
authored
Reduce DB queries in GQL resolver for subscriptions (product, inuseby/dependson) (#745)
* Eagerload SubscriptionTable.product in initial query to avoid lazy loading Signed-off-by: Mark90 <[email protected]> * Add dataloaders for in_use_by and depends_on subscriptions Signed-off-by: Mark90 <[email protected]> * Add custom_context_getter option to register_graphql() Signed-off-by: Mark90 <[email protected]> * Bump version to 2.7.6rc1 --------- Signed-off-by: Mark90 <[email protected]>
1 parent 2e06d00 commit 62d64fd

File tree

10 files changed

+267
-19
lines changed

10 files changed

+267
-19
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 2.7.5
2+
current_version = 2.7.6rc1
33
commit = False
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(rc(?P<build>\d+))?

orchestrator/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
"""This is the orchestrator workflow engine."""
1515

16-
__version__ = "2.7.5"
16+
__version__ = "2.7.6rc1"
1717

1818
from orchestrator.app import OrchestratorCore
1919
from orchestrator.settings import app_settings

orchestrator/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from orchestrator.domain import SUBSCRIPTION_MODEL_REGISTRY, SubscriptionModel
5151
from orchestrator.exception_handlers import problem_detail_handler
5252
from orchestrator.graphql import Mutation, Query, create_graphql_router
53+
from orchestrator.graphql.schema import ContextGetterFactory
5354
from orchestrator.graphql.schemas.subscription import SubscriptionInterface
5455
from orchestrator.graphql.types import ScalarOverrideType, StrawberryModelType
5556
from orchestrator.log_config import LOGGER_OVERRIDES
@@ -205,6 +206,7 @@ def register_graphql(
205206
graphql_models: StrawberryModelType | None = None,
206207
scalar_overrides: ScalarOverrideType | None = None,
207208
extensions: list | None = None,
209+
custom_context_getter: ContextGetterFactory | None = None,
208210
) -> None:
209211
new_router = create_graphql_router(
210212
self.auth_manager,
@@ -216,6 +218,7 @@ def register_graphql(
216218
graphql_models,
217219
scalar_overrides,
218220
extensions=extensions,
221+
custom_context_getter=custom_context_getter,
219222
)
220223
if not self.graphql_router:
221224
self.graphql_router = new_router

orchestrator/graphql/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
OrchestratorSchema,
2525
Query,
2626
create_graphql_router,
27-
get_context,
27+
default_context_getter,
2828
)
2929
from orchestrator.graphql.schemas import DEFAULT_GRAPHQL_MODELS
3030
from orchestrator.graphql.types import SCALAR_OVERRIDES
@@ -38,7 +38,7 @@
3838
"Mutation",
3939
"OrchestratorGraphqlRouter",
4040
"OrchestratorSchema",
41-
"get_context",
41+
"default_context_getter",
4242
"create_graphql_router",
4343
"EnumDict",
4444
"add_class_to_strawberry",

orchestrator/graphql/loaders/__init__.py

Whitespace-only changes.
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
from itertools import chain
2+
from typing import Any, NamedTuple
3+
from uuid import UUID
4+
5+
import structlog
6+
from sqlalchemy import Row, select
7+
from sqlalchemy import Text as SaText
8+
from sqlalchemy import cast as sa_cast
9+
from sqlalchemy.orm import aliased
10+
from strawberry.dataloader import DataLoader
11+
12+
from orchestrator.db import (
13+
ResourceTypeTable,
14+
SubscriptionInstanceTable,
15+
SubscriptionInstanceValueTable,
16+
SubscriptionTable,
17+
db,
18+
)
19+
from orchestrator.db.models import (
20+
SubscriptionInstanceRelationTable,
21+
)
22+
from orchestrator.services.subscriptions import RELATION_RESOURCE_TYPES
23+
from orchestrator.types import SubscriptionLifecycle
24+
25+
logger = structlog.get_logger(__name__)
26+
27+
28+
class Relation(NamedTuple):
29+
depends_on_sub_id: UUID
30+
in_use_by_sub_id: UUID
31+
32+
33+
def _get_instance_relations(instance_relations_query: Any) -> list[Relation]:
34+
def to_relation(row: Row[Any]) -> Relation:
35+
return Relation(row[0], row[1])
36+
37+
return [to_relation(row) for row in db.session.execute(instance_relations_query)]
38+
39+
40+
async def _get_in_use_by_instance_relations(subscription_ids: list[UUID], filter_statuses: list[str]) -> list[Relation]:
41+
"""Get in_use_by by relations through subscription instance hierarchy."""
42+
in_use_by_subscriptions = aliased(SubscriptionTable)
43+
in_use_by_instances = aliased(SubscriptionInstanceTable)
44+
depends_on_instances = aliased(SubscriptionInstanceTable)
45+
46+
query_get_in_use_by_ids = (
47+
select(depends_on_instances.subscription_id, in_use_by_instances.subscription_id)
48+
.distinct()
49+
.join(in_use_by_instances.subscription)
50+
.join(in_use_by_instances.depends_on_block_relations)
51+
.join(depends_on_instances, SubscriptionInstanceRelationTable.depends_on)
52+
.join(in_use_by_subscriptions, depends_on_instances.subscription)
53+
.filter(depends_on_instances.subscription_id.in_(set(subscription_ids)))
54+
.filter(in_use_by_instances.subscription_id != depends_on_instances.subscription_id)
55+
.filter(in_use_by_subscriptions.status.in_(filter_statuses))
56+
)
57+
58+
return _get_instance_relations(query_get_in_use_by_ids)
59+
60+
61+
async def _get_depends_on_instance_relations(
62+
subscription_ids: list[UUID], filter_statuses: list[str]
63+
) -> list[Relation]:
64+
"""Get depends_on relations through subscription instance hierarchy."""
65+
in_use_by_instances = aliased(SubscriptionInstanceTable)
66+
depends_on_instances = aliased(SubscriptionInstanceTable)
67+
depends_on_subscriptions = aliased(SubscriptionTable)
68+
69+
query_get_depends_on_ids = (
70+
select(depends_on_instances.subscription_id, in_use_by_instances.subscription_id)
71+
.distinct()
72+
.join(depends_on_instances.subscription)
73+
.join(depends_on_instances.in_use_by_block_relations)
74+
.join(in_use_by_instances, SubscriptionInstanceRelationTable.in_use_by)
75+
.join(depends_on_subscriptions, in_use_by_instances.subscription)
76+
.filter(in_use_by_instances.subscription_id.in_(set(subscription_ids)))
77+
.filter(depends_on_instances.subscription_id != in_use_by_instances.subscription_id)
78+
.filter(depends_on_subscriptions.status.in_(filter_statuses))
79+
)
80+
81+
return _get_instance_relations(query_get_depends_on_ids)
82+
83+
84+
def _get_resource_type_relations(resource_type_relations_query: Any) -> list[Relation]:
85+
def to_relation(row: Row[Any]) -> Relation:
86+
return Relation(UUID(row[0]), row[1])
87+
88+
return [to_relation(row) for row in db.session.execute(resource_type_relations_query)]
89+
90+
91+
async def _get_in_use_by_resource_type_relations(
92+
subscription_ids: list[UUID], filter_statuses: list[str]
93+
) -> list[Relation]:
94+
"""Get in_use_by relations through resource types."""
95+
logger.warning("Using legacy RELATION_RESOURCE_TYPES to find in_use_by subs")
96+
97+
in_use_by_subscriptions = aliased(SubscriptionTable)
98+
depends_on_instance_values = aliased(SubscriptionInstanceValueTable)
99+
100+
# Convert UUIDs to string
101+
unique_subscription_ids = set(map(str, subscription_ids))
102+
103+
query_get_in_use_by_ids = (
104+
select(depends_on_instance_values.value, in_use_by_subscriptions.subscription_id)
105+
.select_from(depends_on_instance_values)
106+
.join(SubscriptionInstanceTable)
107+
.join(in_use_by_subscriptions)
108+
.join(ResourceTypeTable)
109+
.filter(ResourceTypeTable.resource_type.in_(RELATION_RESOURCE_TYPES))
110+
.filter(depends_on_instance_values.value.in_(unique_subscription_ids))
111+
.filter(in_use_by_subscriptions.status.in_(filter_statuses))
112+
)
113+
114+
return _get_resource_type_relations(query_get_in_use_by_ids)
115+
116+
117+
async def _get_depends_on_resource_type_relations(
118+
subscription_ids: list[UUID], filter_statuses: list[str]
119+
) -> list[Relation]:
120+
"""Get depends_on relations through resource types."""
121+
logger.warning("Using legacy RELATION_RESOURCE_TYPES to find depends_on subs")
122+
123+
depends_on_subscriptions = aliased(SubscriptionTable)
124+
in_use_by_instances = aliased(SubscriptionInstanceTable)
125+
in_use_by_instance_values = aliased(SubscriptionInstanceValueTable)
126+
127+
unique_subscription_ids = set(subscription_ids)
128+
129+
query_get_depends_on_ids = (
130+
select(in_use_by_instance_values.value, in_use_by_instances.subscription_id)
131+
.select_from(in_use_by_instance_values)
132+
.join(in_use_by_instances)
133+
.join(
134+
depends_on_subscriptions,
135+
in_use_by_instance_values.value == sa_cast(depends_on_subscriptions.subscription_id, SaText),
136+
)
137+
.join(ResourceTypeTable)
138+
.filter(ResourceTypeTable.resource_type.in_(RELATION_RESOURCE_TYPES))
139+
.filter(in_use_by_instances.subscription_id.in_(unique_subscription_ids))
140+
.filter(depends_on_subscriptions.status.in_(filter_statuses))
141+
)
142+
143+
return _get_resource_type_relations(query_get_depends_on_ids)
144+
145+
146+
async def _get_in_use_by_relations(subscription_ids: list[UUID], filter_statuses: list[str]) -> list[Relation]:
147+
if RELATION_RESOURCE_TYPES:
148+
# Find relations through resource types
149+
resource_type_relations = await _get_in_use_by_resource_type_relations(subscription_ids, filter_statuses)
150+
else:
151+
resource_type_relations = []
152+
# Find relations through instance hierarchy
153+
instance_relations = await _get_in_use_by_instance_relations(subscription_ids, filter_statuses)
154+
return list(chain(resource_type_relations, instance_relations))
155+
156+
157+
async def _get_depends_on_relations(subscription_ids: list[UUID], filter_statuses: list[str]) -> list[Relation]:
158+
if RELATION_RESOURCE_TYPES:
159+
# Find relations through resource types
160+
resource_type_relations = await _get_depends_on_resource_type_relations(subscription_ids, filter_statuses)
161+
else:
162+
resource_type_relations = []
163+
# Find relations through instance hierarchy
164+
instance_relations = await _get_depends_on_instance_relations(subscription_ids, filter_statuses)
165+
return list(chain(resource_type_relations, instance_relations))
166+
167+
168+
async def in_use_by_subs_loader(keys: list[tuple[UUID, list[str] | None]]) -> list[list[SubscriptionTable]]:
169+
"""GraphQL dataloader to efficiently get the in_use_by SubscriptionTables for multiple subscription_ids."""
170+
subscription_ids = [key[0] for key in keys]
171+
filter_statuses: list[str] = keys[0][1] or SubscriptionLifecycle.values()
172+
173+
in_use_by_relations = await _get_in_use_by_relations(subscription_ids, filter_statuses)
174+
175+
# Retrieve SubscriptionTable for all unique inuseby ids
176+
unique_in_use_by_ids = {row.in_use_by_sub_id for row in in_use_by_relations}
177+
_in_use_by_subs = db.session.execute(
178+
select(SubscriptionTable).filter(SubscriptionTable.subscription_id.in_(unique_in_use_by_ids))
179+
).scalars()
180+
in_use_by_subs = {subscription.subscription_id: subscription for subscription in _in_use_by_subs}
181+
182+
# group (more_itertools.bucket doesn't seem to work for tuple of uuids)
183+
subscription_in_use_by_ids: dict[UUID, list[UUID]] = {}
184+
for relation in in_use_by_relations:
185+
subscription_in_use_by_ids.setdefault(relation.depends_on_sub_id, []).append(relation.in_use_by_sub_id)
186+
187+
def get_in_use_by_subs(depends_on_id: UUID) -> list[SubscriptionTable]:
188+
in_use_by_ids = subscription_in_use_by_ids.get(depends_on_id, [])
189+
return [in_use_by_sub for id_ in in_use_by_ids if (in_use_by_sub := in_use_by_subs.get(id_))]
190+
191+
# Important (as with any dataloader)
192+
# Return the list of inuseby subs in the exact same order as the ids passed to this function
193+
return [get_in_use_by_subs(subscription_id) for subscription_id in subscription_ids]
194+
195+
196+
async def depends_on_subs_loader(keys: list[tuple[UUID, list[str] | None]]) -> list[list[SubscriptionTable]]:
197+
"""GraphQL dataloader to efficiently get the depends_on SubscriptionTables for multiple subscription_ids."""
198+
subscription_ids = [key[0] for key in keys]
199+
filter_statuses: list[str] = keys[0][1] or SubscriptionLifecycle.values()
200+
201+
depends_on_relations = await _get_depends_on_relations(subscription_ids, filter_statuses)
202+
203+
# Retrieve SubscriptionTable for all unique dependson ids
204+
unique_depends_on_ids = {row.depends_on_sub_id for row in depends_on_relations}
205+
_depends_on_subs = db.session.execute(
206+
select(SubscriptionTable).filter(SubscriptionTable.subscription_id.in_(unique_depends_on_ids))
207+
).scalars()
208+
depends_on_subs = {subscription.subscription_id: subscription for subscription in _depends_on_subs}
209+
210+
# group (more_itertools.bucket doesn't seem to work for tuple of uuids)
211+
subscription_depends_on_ids: dict[UUID, list[UUID]] = {}
212+
for relation in depends_on_relations:
213+
subscription_depends_on_ids.setdefault(relation.in_use_by_sub_id, []).append(relation.depends_on_sub_id)
214+
215+
def get_depends_on_subs(in_use_by_id: UUID) -> list[SubscriptionTable]:
216+
depends_on_ids = subscription_depends_on_ids.get(in_use_by_id, [])
217+
return [depends_on_sub for id_ in depends_on_ids if (depends_on_sub := depends_on_subs.get(id_))]
218+
219+
# Important (as with any dataloader)
220+
# Return the list of dependson subs in the exact same order as the ids passed to this function
221+
return [get_depends_on_subs(subscription_id) for subscription_id in subscription_ids]
222+
223+
224+
SubsLoaderType = DataLoader[tuple[UUID, list[str] | None], list[SubscriptionTable]]

orchestrator/graphql/resolvers/subscription.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
13-
1413
from typing import cast
1514
from uuid import UUID
1615

1716
import structlog
1817
from graphql import GraphQLError
1918
from pydantic.alias_generators import to_camel as to_lower_camel
2019
from sqlalchemy import Select, func, select
20+
from sqlalchemy.orm import contains_eager
2121
from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic
2222

2323
from nwastdlib.asyncio import gather_nice
@@ -125,7 +125,16 @@ async def resolve_subscriptions(
125125
filter=pydantic_filter_by,
126126
query=query,
127127
)
128-
stmt = select(SubscriptionTable).join(ProductTable)
128+
129+
stmt = (
130+
select(SubscriptionTable)
131+
.join(ProductTable)
132+
.options(
133+
# contains_eager() is needed because .join() does not eagerload, unlike options(joinedload())
134+
# (and using joinedload() is not possible because of filter_subscriptions())
135+
contains_eager(SubscriptionTable.product),
136+
)
137+
)
129138

130139
stmt = filter_subscriptions(stmt, pydantic_filter_by, _error_handler)
131140
if query is not None:

orchestrator/graphql/schema.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from collections.abc import Callable, Iterable
1414
from http import HTTPStatus
1515
from pathlib import Path
16-
from typing import Any, Coroutine
16+
from typing import Any, Coroutine, Protocol
1717

1818
import strawberry
1919
import structlog
@@ -134,17 +134,26 @@ def process_errors(
134134
StrawberryLogger.error(error, execution_context)
135135

136136

137-
def get_context(
137+
class ContextGetterFactory(Protocol):
138+
def __call__(
139+
self,
140+
auth_manager: AuthManager,
141+
graphql_models: StrawberryModelType,
142+
broadcast_thread: ProcessDataBroadcastThread | None = None,
143+
) -> Callable[[], Coroutine[Any, Any, OrchestratorContext]]: ...
144+
145+
146+
def default_context_getter(
138147
auth_manager: AuthManager,
139148
graphql_models: StrawberryModelType,
140149
broadcast_thread: ProcessDataBroadcastThread | None = None,
141150
) -> Callable[[], Coroutine[Any, Any, OrchestratorContext]]:
142-
async def _get_context() -> OrchestratorContext:
151+
async def context_getter() -> OrchestratorContext:
143152
return OrchestratorContext(
144153
auth_manager=auth_manager, graphql_models=graphql_models, broadcast_thread=broadcast_thread
145154
)
146155

147-
return _get_context
156+
return context_getter
148157

149158

150159
def get_extensions(mutation: Any, query: Any) -> Iterable[type[SchemaExtension]]:
@@ -169,6 +178,7 @@ def create_graphql_router(
169178
graphql_models: StrawberryModelType | None = None,
170179
scalar_overrides: ScalarOverrideType | None = None,
171180
extensions: list | None = None,
181+
custom_context_getter: ContextGetterFactory | None = None,
172182
) -> OrchestratorGraphqlRouter:
173183
scalar_overrides = scalar_overrides if scalar_overrides else dict(SCALAR_OVERRIDES)
174184
models = graphql_models if graphql_models else dict(DEFAULT_GRAPHQL_MODELS)
@@ -190,8 +200,9 @@ def create_graphql_router(
190200
scalar_overrides=scalar_overrides,
191201
)
192202

203+
context_getter_factory = custom_context_getter or default_context_getter
193204
return OrchestratorGraphqlRouter(
194205
schema,
195-
context_getter=get_context(auth_manager, models, broadcast_thread),
206+
context_getter=context_getter_factory(auth_manager, models, broadcast_thread),
196207
graphiql=app_settings.SERVE_GRAPHQL_UI,
197208
)

0 commit comments

Comments
 (0)