Skip to content

Commit 52b18e4

Browse files
authored
Merge pull request #169 from neo4j/visualize-from-neo4j-driver
Allow `Driver` as argument to `from_neo4j`
2 parents c01c10f + a7c6cbf commit 52b18e4

File tree

6 files changed

+116
-15
lines changed

6 files changed

+116
-15
lines changed

changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
## New features
1010

11+
* Allow passing a `neo4j.Driver` instance as input to `from_neo4j`, in which case the driver will be used internally to fetch the graph data using a simple query
12+
1113

1214
## Bug fixes
1315

docs/source/integration.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,16 @@ Once you have installed the additional dependency, you can use the :doc:`from_ne
184184
to import query results from Neo4j.
185185

186186
The ``from_neo4j`` method takes one mandatory positional parameter:
187-
188-
* A ``result`` representing the query result either in form of `neo4j.graph.Graph` or `neo4j.Result`.
187+
A ``data`` argument representing either a query result in the shape of a ``neo4j.graph.Graph`` or ``neo4j.Result``, or a
188+
``neo4j.Driver`` in which case a simple default query will be executed internally to retrieve the graph data.
189189

190190
We can also provide an optional ``size_property`` parameter, which should refer to a node property,
191191
and will be used to determine the sizes of the nodes in the visualization.
192192

193-
The ``node_caption`` and ``relationship_caption`` parameters are also optional, and indicate the node and relationship properties to use for the captions of each element in the visualization.
193+
The ``node_caption`` and ``relationship_caption`` parameters are also optional, and indicate the node and relationship
194+
properties to use for the captions of each element in the visualization.
195+
By default, the captions will be set to the node labels relationship types, but you can specify any property that
196+
exists on these entities.
194197

195198
The last optional property, ``node_radius_min_max``, can be used (and is used by default) to scale the node sizes for
196199
the visualization.

python-wrapper/src/neo4j_viz/gds.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from itertools import chain
45
from typing import Optional
56
from uuid import uuid4
@@ -99,6 +100,9 @@ def from_gds(
99100

100101
node_count = G.node_count()
101102
if node_count > max_node_count:
103+
warnings.warn(
104+
f"The '{G.name()}' projection's node count ({G.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed"
105+
)
102106
sampling_ratio = float(max_node_count) / node_count
103107
sample_name = f"neo4j-viz_sample_{uuid4()}"
104108
G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)

python-wrapper/src/neo4j_viz/neo4j.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
import warnings
34
from typing import Optional, Union
45

56
import neo4j.graph
6-
from neo4j import Result
7+
from neo4j import Driver, Result, RoutingControl
78
from pydantic import BaseModel, ValidationError
89

910
from neo4j_viz.node import Node
@@ -20,14 +21,15 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) ->
2021

2122

2223
def from_neo4j(
23-
result: Union[neo4j.graph.Graph, Result],
24+
data: Union[neo4j.graph.Graph, Result, Driver],
2425
size_property: Optional[str] = None,
2526
node_caption: Optional[str] = "labels",
2627
relationship_caption: Optional[str] = "type",
2728
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
29+
row_limit: int = 10_000,
2830
) -> VisualizationGraph:
2931
"""
30-
Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object.
32+
Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result` or Neo4j `Driver`.
3133
3234
All node and relationship properties will be included in the visualization graph.
3335
If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as
@@ -36,8 +38,9 @@ def from_neo4j(
3638
3739
Parameters
3840
----------
39-
result : Union[neo4j.graph.Graph, Result]
40-
Query result either in shape of a Graph or result.
41+
data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.Driver]
42+
Either a query result in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in
43+
which case a simple default query will be executed internally to retrieve the graph data.
4144
size_property : str, optional
4245
Property to use for node size, by default None.
4346
node_caption : str, optional
@@ -47,14 +50,32 @@ def from_neo4j(
4750
node_radius_min_max : tuple[float, float], optional
4851
Minimum and maximum node radius, by default (3, 60).
4952
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
53+
row_limit : int, optional
54+
Maximum number of rows to return from the query, by default 10_000.
55+
This is only used if a `neo4j.Driver` is passed as `result` argument, otherwise the limit is ignored.
5056
"""
5157

52-
if isinstance(result, Result):
53-
graph = result.graph()
54-
elif isinstance(result, neo4j.graph.Graph):
55-
graph = result
58+
if isinstance(data, Result):
59+
graph = data.graph()
60+
elif isinstance(data, neo4j.graph.Graph):
61+
graph = data
62+
elif isinstance(data, Driver):
63+
rel_count = data.execute_query(
64+
"MATCH ()-[r]->() RETURN count(r) as count",
65+
routing_=RoutingControl.READ,
66+
result_transformer_=Result.single,
67+
).get("count") # type: ignore[union-attr]
68+
if rel_count > row_limit:
69+
warnings.warn(
70+
f"Database relationship count ({rel_count}) exceeds `row_limit` ({row_limit}), so limiting will be applied. Increase the `row_limit` if needed"
71+
)
72+
graph = data.execute_query(
73+
f"MATCH (n)-[r]->(m) RETURN n,r,m LIMIT {row_limit}",
74+
routing_=RoutingControl.READ,
75+
result_transformer_=Result.graph,
76+
)
5677
else:
57-
raise ValueError(f"Invalid input type `{type(result)}`. Expected `neo4j.Graph` or `neo4j.Result`")
78+
raise ValueError(f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`")
5879

5980
all_node_field_aliases = Node.all_validation_aliases()
6081
all_rel_field_aliases = Relationship.all_validation_aliases()

python-wrapper/tests/test_gds.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Any
23

34
import pandas as pd
@@ -267,7 +268,13 @@ def test_from_gds_sample(gds: Any) -> None:
267268
from neo4j_viz.gds import from_gds
268269

269270
with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G:
270-
VG = from_gds(gds, G)
271+
with pytest.warns(
272+
UserWarning,
273+
match=re.escape(
274+
"The 'hello' projection's node count (11000) exceeds `max_node_count` (10000), so subsampling will be applied. Increase `max_node_count` if needed"
275+
),
276+
):
277+
VG = from_gds(gds, G)
271278

272279
assert len(VG.nodes) >= 9_500
273280
assert len(VG.nodes) <= 10_500

python-wrapper/tests/test_neo4j.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import re
12
from typing import Generator
23

34
import neo4j
45
import pytest
5-
from neo4j import Session
6+
from neo4j import Driver, Session
67

78
from neo4j_viz.neo4j import from_neo4j
89
from neo4j_viz.node import Node
@@ -201,3 +202,66 @@ def test_from_neo4j_rel_error(neo4j_session: Session) -> None:
201202
match="Error for relationship property 'caption_align' with provided input 'banana'. Reason: Input should be 'top', 'center' or 'bottom'",
202203
):
203204
from_neo4j(graph)
205+
206+
207+
@pytest.mark.requires_neo4j_and_gds
208+
def test_from_neo4j_graph_driver(neo4j_session: Session, neo4j_driver: Driver) -> None:
209+
graph = neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph()
210+
211+
# Note that this tests requires an empty Neo4j database, as it just fetches everything
212+
VG = from_neo4j(neo4j_driver)
213+
214+
sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"])
215+
node_ids: list[str] = [node.element_id for node in sorted_nodes]
216+
217+
expected_nodes = [
218+
Node(
219+
id=node_ids[0],
220+
caption="_CI_A",
221+
properties=dict(
222+
labels=["_CI_A"],
223+
name="Alice",
224+
height=20,
225+
id=42,
226+
_id=1337,
227+
caption="hello",
228+
),
229+
),
230+
Node(
231+
id=node_ids[1],
232+
caption="_CI_A:_CI_B",
233+
size=11,
234+
properties=dict(
235+
labels=["_CI_A", "_CI_B"],
236+
name="Bob",
237+
height=10,
238+
id=84,
239+
__labels=[1, 2],
240+
),
241+
),
242+
]
243+
244+
assert len(VG.nodes) == 2
245+
assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes
246+
247+
assert len(VG.relationships) == 2
248+
vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo")
249+
assert vg_rels == [
250+
(node_ids[0], node_ids[1], "KNOWS"),
251+
(node_ids[1], node_ids[0], "RELATED"),
252+
]
253+
254+
255+
@pytest.mark.requires_neo4j_and_gds
256+
def test_from_neo4j_graph_row_limit_warning(neo4j_session: Session, neo4j_driver: Driver) -> None:
257+
neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph()
258+
259+
with pytest.warns(
260+
UserWarning,
261+
match=re.escape(
262+
"Database relationship count (2) exceeds `row_limit` (1), so limiting will be applied. Increase the `row_limit` if needed"
263+
),
264+
):
265+
VG = from_neo4j(neo4j_driver, row_limit=1)
266+
267+
assert len(VG.relationships) == 1

0 commit comments

Comments
 (0)