Skip to content

Commit 870b10f

Browse files
committed
support standard temporal sampling and updated PLC parameters
1 parent aa91e27 commit 870b10f

File tree

6 files changed

+143
-60
lines changed

6 files changed

+143
-60
lines changed

python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
batch_size: int = 16, # Refers to number of edges per batch.
6060
compression: Optional[str] = None,
6161
local_seeds_per_call: Optional[int] = None,
62+
temporal_comparison: Optional[str] = None,
6263
**kwargs,
6364
):
6465
"""
@@ -128,12 +129,19 @@ def __init__(
128129
all workers. If not provided, it will be automatically
129130
calculated.
130131
See cugraph_pyg.sampler.BaseDistributedSampler.
132+
temporal_comparison: str (optional, default='<=', i.e. monotonically decreasing)
133+
The comparison operator for temporal sampling (>, <, >=, <=, last).
134+
Note that this should be 'last' for temporal_strategy='last'.
135+
See cugraph_pyg.sampler.BaseDistributedSampler.
131136
**kwargs
132137
Other keyword arguments passed to the superclass.
133138
"""
134139

135140
subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type)
136141

142+
if temporal_comparison is None:
143+
temporal_comparison = "<="
144+
137145
if not directed:
138146
subgraph_type = torch_geometric.sampler.base.SubgraphType.induced
139147
warnings.warn(
@@ -172,17 +180,15 @@ def __init__(
172180

173181
is_temporal = (edge_label_time is not None) and (time_attr is not None)
174182

183+
if (edge_label_time is None) != (time_attr is None):
184+
warnings.warn(
185+
"Edge-based temporal sampling requires that both edge_label_time and time_attr are provided. Defaulting to non-temporal sampling."
186+
)
187+
175188
if weight_attr is not None:
176189
graph_store._set_weight_attr((feature_store, weight_attr))
177190
if is_temporal:
178-
# TODO Confirm that time is an edge attribute
179-
# TODO Add support for time override (see rapidsai/cugraph#5263)
180191
graph_store._set_etime_attr((feature_store, time_attr))
181-
warnings.warn(
182-
"Temporal sampling in cuGraph-PyG is currently only forward in time"
183-
" instead of the expected backward in time. This will be fixed in a"
184-
" future release."
185-
)
186192

187193
if isinstance(num_neighbors, dict):
188194
sorted_keys, _, _ = graph_store._numeric_edge_types
@@ -209,6 +215,7 @@ def __init__(
209215
biased=(weight_attr is not None),
210216
heterogeneous=(not graph_store.is_homogeneous),
211217
temporal=is_temporal,
218+
temporal_comparison=temporal_comparison,
212219
vertex_type_offsets=graph_store._vertex_offset_array,
213220
num_edge_types=len(graph_store.get_all_edge_attrs()),
214221
),

python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
batch_size: int = 16,
5757
compression: Optional[str] = None,
5858
local_seeds_per_call: Optional[int] = None,
59+
temporal_comparison: Optional[str] = None,
5960
**kwargs,
6061
):
6162
"""
@@ -121,12 +122,19 @@ def __init__(
121122
all workers. If not provided, it will be automatically
122123
calculated.
123124
See cugraph_pyg.sampler.BaseDistributedSampler.
125+
temporal_comparison: str (optional, default='<=', i.e. monotonically decreasing)
126+
The comparison operator for temporal sampling (>, <, >=, <=, last).
127+
Note that this should be 'last' for temporal_strategy='last'.
128+
See cugraph_pyg.sampler.BaseDistributedSampler.
124129
**kwargs
125130
Other keyword arguments passed to the superclass.
126131
"""
127132

128133
subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type)
129134

135+
if temporal_comparison is None:
136+
temporal_comparison = "<="
137+
130138
if not directed:
131139
subgraph_type = torch_geometric.sampler.base.SubgraphType.induced
132140
warnings.warn(
@@ -166,15 +174,18 @@ def __init__(
166174
is_temporal = time_attr is not None
167175

168176
if is_temporal:
169-
# TODO Confirm that time is an edge attribute
170-
# TODO Add support for time override (see rapidsai/cugraph#5263)
171177
graph_store._set_etime_attr((feature_store, time_attr))
172178

173-
warnings.warn(
174-
"Temporal sampling in cuGraph-PyG is currently only forward in time"
175-
" instead of the expected backward in time. This will be fixed in a"
176-
" future release."
177-
)
179+
if input_time is None:
180+
input_type, input_nodes, _ = (
181+
torch_geometric.loader.utils.get_input_nodes(
182+
data, input_nodes, None
183+
)
184+
)
185+
if input_type is None:
186+
input_type = list(graph_store._vertex_offsets.keys())[0]
187+
# will assume the time attribute exists for nodes as well
188+
input_time = feature_store[input_type, time_attr, None][input_nodes]
178189

179190
if weight_attr is not None:
180191
graph_store._set_weight_attr((feature_store, weight_attr))
@@ -204,6 +215,7 @@ def __init__(
204215
biased=(weight_attr is not None),
205216
heterogeneous=(not graph_store.is_homogeneous),
206217
temporal=is_temporal,
218+
temporal_comparison=temporal_comparison,
207219
vertex_type_offsets=graph_store._vertex_offset_array,
208220
num_edge_types=len(graph_store.get_all_edge_attrs()),
209221
),

python/cugraph-pyg/cugraph_pyg/loader/node_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __init__(
116116
if input_id is None
117117
else input_id,
118118
node=input_nodes,
119-
time=None,
119+
time=input_time,
120120
input_type=input_type,
121121
)
122122

0 commit comments

Comments
 (0)