Skip to content

Commit 5f0438b

Browse files
committed
wip(graph): contract/fold prepared graph
1 parent 021744c commit 5f0438b

File tree

2 files changed

+85
-2
lines changed

2 files changed

+85
-2
lines changed

src/lsst/cmservice/common/graph.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,33 @@ async def delete_node_from_graph(
561561

562562
if commit:
563563
await session.commit()
564+
565+
566+
async def contract_step_nodes(graph: nx.DiGraph) -> nx.DiGraph:
567+
"""Manipulates a graph with prepared steps by contracting the dynamic
568+
second-tier elements into the parent step. The contracted elements are not
569+
preserved in the graph.
570+
571+
Parameters
572+
----------
573+
graph : networkx.DiGraph
574+
A graph object where each node is a full Node Model.
575+
"""
576+
# contract any group nodes in the graph into their step
577+
g2 = graph.copy()
578+
for node in graph:
579+
model: Node = graph.nodes(data="model")[node]
580+
if model.kind is not ManifestKind.group:
581+
continue
582+
step = list(graph.predecessors(node)).pop()
583+
g2 = nx.contracted_nodes(g2, step, node, self_loops=False, store_contraction_as=None)
584+
585+
# repeat for the collect steps
586+
g3 = g2.copy()
587+
for node in g2:
588+
model = graph.nodes(data="model")[node]
589+
if model.kind is not ManifestKind.collect_groups:
590+
continue
591+
step = list(g2.predecessors(node)).pop()
592+
g3 = nx.contracted_nodes(g3, step, node, self_loops=False, store_contraction_as=None)
593+
return g3

tests/v2/test_graph.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests graph operations using v2 objects"""
22

33
import random
4+
from asyncio import sleep
45
from urllib.parse import urlparse
56
from uuid import UUID, uuid4
67

@@ -9,15 +10,15 @@
910
from httpx import AsyncClient
1011
from sqlmodel.ext.asyncio.session import AsyncSession
1112

12-
from lsst.cmservice.common.enums import StatusEnum
13+
from lsst.cmservice.common.enums import ManifestKind, StatusEnum
1314
from lsst.cmservice.common.graph import (
1415
delete_node_from_graph,
1516
graph_from_edge_list_v2,
1617
processable_graph_nodes,
1718
validate_graph,
1819
)
1920
from lsst.cmservice.common.types import AnyAsyncSession
20-
from lsst.cmservice.db.campaigns_v2 import Edge
21+
from lsst.cmservice.db.campaigns_v2 import Edge, Node
2122

2223
pytestmark = pytest.mark.asyncio(loop_scope="module")
2324
"""All tests in this module will run in the same event loop."""
@@ -318,3 +319,55 @@ async def test_delete_node_from_graph(
318319
commit=True,
319320
)
320321
...
322+
323+
324+
async def test_graph_contraction(
325+
aclient: AsyncClient, session: AsyncSession, test_campaign_groups: str
326+
) -> None:
327+
"""Tests the stepwise contraction of a graph."""
328+
edge_list = [Edge.model_validate(edge) for edge in (await aclient.get(test_campaign_groups)).json()]
329+
graph = await graph_from_edge_list_v2(edge_list, session, node_view="model")
330+
331+
# "prepare" all the "step" nodes in the graph
332+
for node in graph:
333+
model: Node = graph.nodes(data="model")[node]
334+
if model.kind is not ManifestKind.step:
335+
continue
336+
x = await aclient.patch(
337+
f"/cm-service/v2/nodes/{model.id}",
338+
headers={"Content-Type": "application/merge-patch+json"},
339+
json={"status": "ready"},
340+
)
341+
assert x.is_success
342+
while not len((await aclient.get(x.headers["StatusUpdate"])).json()): # noqa: ASYNC110
343+
await sleep(1.0)
344+
345+
# get a fresh copy of the graph
346+
edge_list = [Edge.model_validate(edge) for edge in (await aclient.get(test_campaign_groups)).json()]
347+
graph = await graph_from_edge_list_v2(edge_list, session, node_view="model")
348+
349+
# Use the node contraction function to manipulate the graph by removing
350+
# node_a by contraction into start_node
351+
352+
# second-tier graph nodes will be of type "group" or "collect_groups" but
353+
# in the "simple" graph view we won't have access to the metadata that
354+
# includes the node's "parent" step.
355+
# Instead, we focus on "group" nodes first, as a middle-out contraction.
356+
# - by definition, the predecessor node for a group is that group's step
357+
g2 = graph.copy()
358+
for node in graph:
359+
model = graph.nodes(data="model")[node]
360+
if model.kind is not ManifestKind.group:
361+
continue
362+
step = list(graph.predecessors(node)).pop()
363+
g2 = nx.contracted_nodes(g2, step, node, self_loops=False, store_contraction_as=None)
364+
365+
# repeat for the collect steps
366+
g3 = g2.copy()
367+
for node in g2:
368+
model = graph.nodes(data="model")[node]
369+
if model.kind is not ManifestKind.collect_groups:
370+
continue
371+
step = list(g2.predecessors(node)).pop()
372+
g3 = nx.contracted_nodes(g3, step, node, self_loops=False, store_contraction_as=None)
373+
assert validate_graph(g3)

0 commit comments

Comments
 (0)