|
1 | 1 | """Tests graph operations using v2 objects""" |
2 | 2 |
|
3 | 3 | import random |
| 4 | +from asyncio import sleep |
4 | 5 | from urllib.parse import urlparse |
5 | 6 | from uuid import UUID, uuid4 |
6 | 7 |
|
|
9 | 10 | from httpx import AsyncClient |
10 | 11 | from sqlmodel.ext.asyncio.session import AsyncSession |
11 | 12 |
|
12 | | -from lsst.cmservice.common.enums import StatusEnum |
| 13 | +from lsst.cmservice.common.enums import ManifestKind, StatusEnum |
13 | 14 | from lsst.cmservice.common.graph import ( |
14 | 15 | delete_node_from_graph, |
15 | 16 | graph_from_edge_list_v2, |
16 | 17 | processable_graph_nodes, |
17 | 18 | validate_graph, |
18 | 19 | ) |
19 | 20 | from lsst.cmservice.common.types import AnyAsyncSession |
20 | | -from lsst.cmservice.db.campaigns_v2 import Edge |
| 21 | +from lsst.cmservice.db.campaigns_v2 import Edge, Node |
21 | 22 |
|
22 | 23 | pytestmark = pytest.mark.asyncio(loop_scope="module") |
23 | 24 | """All tests in this module will run in the same event loop.""" |
@@ -318,3 +319,55 @@ async def test_delete_node_from_graph( |
318 | 319 | commit=True, |
319 | 320 | ) |
320 | 321 | ... |
| 322 | + |
| 323 | + |
| 324 | +async def test_graph_contraction( |
| 325 | + aclient: AsyncClient, session: AsyncSession, test_campaign_groups: str |
| 326 | +) -> None: |
| 327 | + """Tests the stepwise contraction of a graph.""" |
| 328 | + edge_list = [Edge.model_validate(edge) for edge in (await aclient.get(test_campaign_groups)).json()] |
| 329 | + graph = await graph_from_edge_list_v2(edge_list, session, node_view="model") |
| 330 | + |
| 331 | + # "prepare" all the "step" nodes in the graph |
| 332 | + for node in graph: |
| 333 | + model: Node = graph.nodes(data="model")[node] |
| 334 | + if model.kind is not ManifestKind.step: |
| 335 | + continue |
| 336 | + x = await aclient.patch( |
| 337 | + f"/cm-service/v2/nodes/{model.id}", |
| 338 | + headers={"Content-Type": "application/merge-patch+json"}, |
| 339 | + json={"status": "ready"}, |
| 340 | + ) |
| 341 | + assert x.is_success |
| 342 | + while not len((await aclient.get(x.headers["StatusUpdate"])).json()): # noqa: ASYNC110 |
| 343 | + await sleep(1.0) |
| 344 | + |
| 345 | + # get a fresh copy of the graph |
| 346 | + edge_list = [Edge.model_validate(edge) for edge in (await aclient.get(test_campaign_groups)).json()] |
| 347 | + graph = await graph_from_edge_list_v2(edge_list, session, node_view="model") |
| 348 | + |
| 349 | + # Use the node contraction function to manipulate the graph by removing |
| 350 | + # node_a by contraction into start_node |
| 351 | + |
| 352 | + # second-tier graph nodes will be of type "group" or "collect_groups" but |
| 353 | + # in the "simple" graph view we won't have access to the metadata that |
| 354 | + # includes the node's "parent" step. |
| 355 | + # Instead, we focus on "group" nodes first, as a middle-out contraction. |
| 356 | + # - by definition, the predecessor node for a group is that group's step |
| 357 | + g2 = graph.copy() |
| 358 | + for node in graph: |
| 359 | + model = graph.nodes(data="model")[node] |
| 360 | + if model.kind is not ManifestKind.group: |
| 361 | + continue |
| 362 | + step = list(graph.predecessors(node)).pop() |
| 363 | + g2 = nx.contracted_nodes(g2, step, node, self_loops=False, store_contraction_as=None) |
| 364 | + |
| 365 | + # repeat for the collect steps |
| 366 | + g3 = g2.copy() |
| 367 | + for node in g2: |
| 368 | + model = graph.nodes(data="model")[node] |
| 369 | + if model.kind is not ManifestKind.collect_groups: |
| 370 | + continue |
| 371 | + step = list(g2.predecessors(node)).pop() |
| 372 | + g3 = nx.contracted_nodes(g3, step, node, self_loops=False, store_contraction_as=None) |
| 373 | + assert validate_graph(g3) |
0 commit comments