Commits (21)
4a93761
Serialize/deserialize component state
NathalieCharbel May 15, 2025
f940506
Run pipeline until/Resume pipeline from
NathalieCharbel May 15, 2025
12d2e26
Remove in memory storage support for pipeline state
NathalieCharbel Jun 4, 2025
8a00015
Add pipeline run_id and ability to save and load state from json file
NathalieCharbel Jun 4, 2025
5f0c892
Add unit tests
NathalieCharbel Jun 4, 2025
22722b8
Update changelog and docs
NathalieCharbel Jun 4, 2025
9a57046
Ruff
NathalieCharbel Jun 4, 2025
d1f7389
Remove state management for component
NathalieCharbel Jun 10, 2025
8bc8ac0
Remove resume_from and run_until and reuse existing run interface
NathalieCharbel Jun 10, 2025
f76e5ca
Add dump and load to InMemoryStore
NathalieCharbel Jun 10, 2025
2cc5b99
Ruff
NathalieCharbel Jun 10, 2025
2a819ea
Add ability to load and dump state by run_id
NathalieCharbel Jun 11, 2025
472eaf1
Allow orchestrator to use a run_id from a previous run
NathalieCharbel Jun 11, 2025
13254dd
Refactor pipeline and validate loaded state
NathalieCharbel Jun 11, 2025
a9413d3
Refactor pipeline run_id management
NathalieCharbel Jun 13, 2025
64bdc66
Ensure previous run_ids are kept in store
NathalieCharbel Jun 13, 2025
d020a7e
Ensure resume run with different run_ids
NathalieCharbel Jun 13, 2025
c97bb97
Cleanup stores
NathalieCharbel Jun 16, 2025
74e7db0
Ensure proper handling of previous run_ids
NathalieCharbel Jun 16, 2025
0ca4c60
Update changelog and docs
NathalieCharbel Jun 16, 2025
2c091c2
Fix orchestrator's way of handling tasks on complete and transitions …
NathalieCharbel Jul 11, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@

- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default.
- Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call).
- Added pipeline state management: `Pipeline.run` accepts optional `from_` and `until` parameters, and the new `dump_state` and `load_state` methods enable pipeline execution checkpointing and resumption.

### Fixed

30 changes: 30 additions & 0 deletions docs/source/user_guide_pipeline.rst
@@ -133,6 +133,36 @@ can be added to the visualization by setting `hide_unused_outputs` to `False`:
webbrowser.open("pipeline_full.html")


*************************
Pipeline State Management
*************************

Pipelines support checkpointing and resumption through state management features:

.. code:: python

    # Run the pipeline, stopping after a given component
    result = await pipeline.run(data, until="component_name")

    # Dump the state of that run to a serializable dictionary
    state = pipeline.dump_state(result.run_id)

    # Later: restore the state and resume execution from a component
    pipeline.load_state(state)
    result = await pipeline.run(data, from_="component_name")

The state contains:

- Pipeline configuration (parameter mappings between components and validation state)
- Execution results (outputs from completed components stored in the ResultStore)
- Final pipeline results from previous runs
- Component-specific state (interface available but not yet implemented by components)

This enables:

- Checkpointing long-running pipelines
- Debugging pipeline execution
- Resuming failed pipelines from the last successful component
- Comparing different implementations of a component against identical inputs: dump the state just before that component and reload it for each candidate, so non-deterministic outputs from earlier components cannot skew the comparison (see the persistence sketch below)
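Because the dumped state is a plain dictionary, it can also be persisted between processes. A minimal sketch, assuming every stored component result is JSON-serializable (the file name is arbitrary):

.. code:: python

    import json

    # first process: run up to the checkpoint and save the state to disk
    result = await pipeline.run(data, until="component_name")
    with open("pipeline_state.json", "w") as f:
        json.dump(pipeline.dump_state(result.run_id), f)

    # second process: restore the state and resume
    with open("pipeline_state.json") as f:
        pipeline.load_state(json.load(f))
    result = await pipeline.run(data, from_="component_name")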


************************
Adding an Event Callback
************************
29 changes: 25 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/orchestrator.py
@@ -19,7 +19,7 @@
import uuid
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, AsyncGenerator
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.exceptions import (
@@ -46,16 +46,29 @@ class Orchestrator:
- finding the next tasks to execute
- building the inputs for each task
- calling the run method on each task
- optionally stopping after a specified component
- optionally starting from a specified component

Once a TaskNode is done, it calls the `on_task_complete` callback
that will save the results, find the next tasks to be executed
(checking that all dependencies are met), and run them.

Partial execution is supported through:
- stop_after: Stop execution after this component completes
- start_from: Start execution from this component instead of roots
"""

def __init__(self, pipeline: Pipeline):
def __init__(
self,
pipeline: Pipeline,
stop_after: Optional[str] = None,
start_from: Optional[str] = None,
):
self.pipeline = pipeline
self.event_notifier = EventNotifier(pipeline.callbacks)
self.run_id = str(uuid.uuid4())
self.stop_after = stop_after
self.start_from = start_from

async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
"""Get inputs and run a specific task. Once the task is done,
@@ -129,7 +142,10 @@ async def on_task_complete(
await self.add_result_for_component(
task.name, res_to_save, is_final=task.is_leaf()
)
# then get the next tasks to be executed
# stop if this is the stop_after node
if self.stop_after and task.name == self.stop_after:
return
# otherwise, get the next tasks to be executed
# and run them in parallel
await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])

@@ -266,7 +282,12 @@ async def run(self, data: dict[str, Any]) -> None:
will handle the task dependencies.
"""
await self.event_notifier.notify_pipeline_started(self.run_id, data)
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
# start from a specific node if requested, otherwise from roots
if self.start_from:
start_nodes = [self.pipeline.get_node_by_name(self.start_from)]
else:
start_nodes = self.pipeline.roots()
tasks = [self.run_task(root, data) for root in start_nodes]
await asyncio.gather(*tasks)
await self.event_notifier.notify_pipeline_finished(
self.run_id, await self.pipeline.get_final_results(self.run_id)
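Taken together, `stop_after` short-circuits `on_task_complete` before any downstream task is scheduled, while `start_from` replaces the root set with a single named node. A minimal sketch of the wiring (component names are hypothetical, and user code would normally go through `Pipeline.run` rather than instantiating the orchestrator directly):

```python
# Stop once "embedder" completes; its results are stored, but
# nothing downstream of it is ever scheduled.
orchestrator = Orchestrator(pipeline, stop_after="embedder")
await orchestrator.run(data)

# Start at "writer" instead of the pipeline roots; its inputs are
# built from results already present in the pipeline's result store.
orchestrator = Orchestrator(pipeline, start_from="writer")
await orchestrator.run(data)
```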
61 changes: 58 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -18,7 +18,7 @@
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional, AsyncGenerator
from typing import Any, Optional, AsyncGenerator, Dict
import asyncio

from neo4j_graphrag.utils.logging import prettify
@@ -563,19 +563,74 @@ async def event_stream(event: Event) -> None:
if event_queue_getter_task and not event_queue_getter_task.done():
event_queue_getter_task.cancel()

async def run(self, data: dict[str, Any]) -> PipelineResult:
async def run(
self,
data: dict[str, Any],
from_: Optional[str] = None,
until: Optional[str] = None,
) -> PipelineResult:
"""Run the pipeline, optionally from a specific component or until a specific component.

Args:
data (dict[str, Any]): The input data for the pipeline
from_ (str | None, optional): If provided, start execution from this component. Defaults to None.
until (str | None, optional): If provided, stop execution after this component. Defaults to None.

Returns:
PipelineResult: The result of the pipeline execution
"""
logger.debug("PIPELINE START")
start_time = default_timer()
self.invalidate()
self.validate_input_data(data)
orchestrator = Orchestrator(self)

# create orchestrator with appropriate start_from and stop_after params
orchestrator = Orchestrator(self, stop_after=until, start_from=from_)

logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)

end_time = default_timer()
logger.debug(
f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
)

return PipelineResult(
run_id=orchestrator.run_id,
result=await self.get_final_results(orchestrator.run_id),
)

def dump_state(self, run_id: str) -> Dict[str, Any]:
"""Dump the current state of the pipeline to a serializable dictionary.

Args:
run_id: The run_id that was used when the pipeline was executed

Returns:
Dict[str, Any]: A serializable dictionary containing the pipeline state
"""
pipeline_state: Dict[str, Any] = {
"run_id": run_id,
"store": self.store.dump(),
"final_results": self.final_results.dump(),
"is_validated": self.is_validated,
}
return pipeline_state

def load_state(self, state: Dict[str, Any]) -> None:
"""Load pipeline state from a serialized dictionary.

Args:
state (Dict[str, Any]): Previously serialized pipeline state, as returned by `dump_state`
"""
# the state is scoped to the run that produced it
run_id = state["run_id"]

# load pipeline state attributes
if "is_validated" in state:
self.is_validated = state["is_validated"]

# load store data for this run
if "store" in state:
self.store.load(run_id, state["store"])

# load final results for this run
if "final_results" in state:
self.final_results.load(run_id, state["final_results"])
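For reference, a sketch of the dictionary `dump_state` produces for a run with one completed component (the UUID and stored values are illustrative only):

```python
{
    "run_id": "0b8df18e-2a4f-4f5e-9c33-0f2c9a1d7e42",
    # ResultStore contents for the run, keyed as "<run_id>:<component_name>"
    "store": {
        "0b8df18e-2a4f-4f5e-9c33-0f2c9a1d7e42:splitter": {"chunks": ["a", "b"]},
    },
    # leaf-component outputs recorded as final results for the run
    "final_results": {},
    # restored by load_state so the resumed run can skip re-validation
    "is_validated": True,
}
```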
58 changes: 58 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/stores.py
@@ -90,6 +90,28 @@ async def add_result_for_component(
async def get_result_for_component(self, run_id: str, task_name: str) -> Any:
return await self.get(self.get_key(run_id, task_name))

@abc.abstractmethod
def dump(self, run_id: str) -> dict[str, Any]:
"""Dump the store state for a specific run_id to a serializable dictionary.

Args:
run_id (str): The run_id to dump data for

Returns:
dict[str, Any]: A serializable dictionary containing the store state for the run_id
"""
pass

@abc.abstractmethod
def load(self, run_id: str, state: dict[str, Any]) -> None:
"""Load the store state for a specific run_id from a serializable dictionary.

Args:
run_id (str): The run_id to load data for
state (dict[str, Any]): A serializable dictionary containing the store state
"""
pass


class InMemoryStore(ResultStore):
"""Simple in-memory store.
@@ -115,3 +115,39 @@ def all(self) -> dict[str, Any]:

def empty(self) -> None:
self._data = {}

def dump(self, run_id: str) -> dict[str, Any]:
"""Dump the store state for a specific run_id to a serializable dictionary.

Args:
run_id (str): The run_id to dump data for

Returns:
dict[str, Any]: A serializable dictionary containing the store state for the run_id
"""
# filter data by run_id prefix
run_id_prefix = f"{run_id}:"
filtered_data = {
key: value
for key, value in self._data.items()
if key.startswith(run_id_prefix)
}
return filtered_data

def load(self, run_id: str, state: dict[str, Any]) -> None:
"""Load the store state for a specific run_id from a serializable dictionary.

Args:
run_id (str): The run_id to load data for
state (dict[str, Any]): A serializable dictionary containing the store state
"""
# clear existing data for this run_id first
run_id_prefix = f"{run_id}:"
keys_to_remove = [
key for key in self._data.keys() if key.startswith(run_id_prefix)
]
for key in keys_to_remove:
Review comment (Contributor): So here we are removing all results from a previous run with this run_id, right?

Reply (NathalieCharbel): yes!

del self._data[key]

# load the new state data
self._data.update(state)
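A minimal usage sketch of the run-scoped dump/load round trip; it assumes the base-class helper `add_result_for_component(run_id, task_name, result)` visible in the hunk header above:

```python
import asyncio


async def demo() -> None:
    store = InMemoryStore()
    # entries are keyed as f"{run_id}:{task_name}" by the base class
    await store.add_result_for_component("run-1", "splitter", {"chunks": 2})
    await store.add_result_for_component("run-2", "splitter", {"chunks": 9})

    # the prefix filter keeps only run-1 entries:
    # {"run-1:splitter": {"chunks": 2}}
    snapshot = store.dump("run-1")

    fresh = InMemoryStore()
    fresh.load("run-1", snapshot)
    print(await fresh.get_result_for_component("run-1", "splitter"))


asyncio.run(demo())
```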
17 changes: 16 additions & 1 deletion tests/unit/experimental/pipeline/test_component.py
@@ -18,7 +18,12 @@

from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from .components import ComponentMultiply, ComponentMultiplyWithContext, IntResultModel
from .components import (
ComponentMultiply,
ComponentMultiplyWithContext,
IntResultModel,
StatefulComponent,
)


def test_component_inputs() -> None:
@@ -87,3 +92,13 @@ class WrongComponent(Component):
"You must implement either `run` or `run_with_context` in Component 'WrongComponent'"
in str(e)
)


def test_stateful_component_serialize_and_load_state() -> None:
c = StatefulComponent()
c.counter = 42
state = c.serialize_state()
assert state == {"counter": 42}
c.counter = 0
c.load_state({"counter": 99})
assert c.counter == 99
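The test imports `StatefulComponent` from the shared `components` fixture module, which this diff does not show. A minimal sketch of what such a fixture could look like (hypothetical, written only to match the `serialize_state`/`load_state` calls exercised above; `IntResultModel` is the fixture result model already defined in that module):

```python
from typing import Any

from neo4j_graphrag.experimental.pipeline import Component


class StatefulComponent(Component):
    def __init__(self) -> None:
        self.counter = 0

    def serialize_state(self) -> dict[str, Any]:
        # expose mutable state as a plain, JSON-serializable dict
        return {"counter": self.counter}

    def load_state(self, state: dict[str, Any]) -> None:
        self.counter = state.get("counter", 0)

    async def run(self) -> IntResultModel:
        self.counter += 1
        return IntResultModel(result=self.counter)
```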