Skip to content

Commit 1b5c02c

Browse files
Rhett-Yingmfbalinyxy235Ubuntu
authored
[release] cherry-pick from master and release for 2.2.1 (#7388)
Co-authored-by: Muhammed Fatih BALIN <[email protected]> Co-authored-by: Xinyu Yao <[email protected]> Co-authored-by: Ubuntu <[email protected]>
1 parent 8873fb2 commit 1b5c02c

File tree

10 files changed

+124
-40
lines changed

10 files changed

+124
-40
lines changed

examples/multigpu/graphbolt/node_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def parse_args():
399399
"--gpu-cache-size",
400400
type=int,
401401
default=0,
402-
help="The capacity of the GPU cache, the number of features to store.",
402+
help="The capacity of the GPU cache in bytes.",
403403
)
404404
parser.add_argument(
405405
"--dataset",

examples/sampling/graphbolt/pyg/node_classification_advanced.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ def parse_args():
350350
help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
351351
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
352352
)
353+
parser.add_argument(
354+
"--gpu-cache-size",
355+
type=int,
356+
default=0,
357+
help="The capacity of the GPU cache in bytes.",
358+
)
353359
parser.add_argument(
354360
"--sample-mode",
355361
default="sample_neighbor",
@@ -403,6 +409,12 @@ def main():
403409

404410
num_classes = dataset.tasks[0].metadata["num_classes"]
405411

412+
if args.gpu_cache_size > 0 and args.feature_device != "cuda":
413+
features._features[("node", None, "feat")] = gb.GPUCachedFeature(
414+
features._features[("node", None, "feat")],
415+
args.gpu_cache_size,
416+
)
417+
406418
train_dataloader, valid_dataloader = (
407419
create_dataloader(
408420
graph=graph,

python/dgl/graphbolt/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55

66
import torch
77
from torch.torch_version import TorchVersion
8+
9+
if TorchVersion(torch.__version__) >= "2.3.0":
10+
# [TODO][https://github.com/dmlc/dgl/issues/7387] Remove or refine below
11+
# check.
12+
# Due to https://github.com/dmlc/dgl/issues/7380, we need to check if dill
13+
# is available before using it.
14+
torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = (
15+
torch.utils._import_utils.dill_available()
16+
)
17+
18+
# pylint: disable=wrong-import-position
819
from torch.utils.data import functional_datapipe
920
from torchdata.datapipes.iter import IterDataPipe
1021

@@ -342,6 +353,7 @@ class CSCFormatBase:
342353
>>> print(csc_foramt_base)
343354
... torch.tensor([1, 4, 2])
344355
"""
356+
345357
indptr: torch.Tensor = None
346358
indices: torch.Tensor = None
347359

python/dgl/graphbolt/feature_fetcher.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,12 @@ def record_stream(tensor):
8888

8989
if self.node_feature_keys and input_nodes is not None:
9090
if is_heterogeneous:
91-
for type_name, feature_names in self.node_feature_keys.items():
92-
nodes = input_nodes[type_name]
93-
if nodes is None:
91+
for type_name, nodes in input_nodes.items():
92+
if type_name not in self.node_feature_keys or nodes is None:
9493
continue
9594
if nodes.is_cuda:
9695
nodes.record_stream(torch.cuda.current_stream())
97-
for feature_name in feature_names:
96+
for feature_name in self.node_feature_keys[type_name]:
9897
node_features[
9998
(type_name, feature_name)
10099
] = record_stream(
@@ -126,21 +125,22 @@ def record_stream(tensor):
126125
if is_heterogeneous:
127126
# Convert edge type to string.
128127
original_edge_ids = {
129-
etype_tuple_to_str(key)
130-
if isinstance(key, tuple)
131-
else key: value
128+
(
129+
etype_tuple_to_str(key)
130+
if isinstance(key, tuple)
131+
else key
132+
): value
132133
for key, value in original_edge_ids.items()
133134
}
134-
for (
135-
type_name,
136-
feature_names,
137-
) in self.edge_feature_keys.items():
138-
edges = original_edge_ids.get(type_name, None)
139-
if edges is None:
135+
for type_name, edges in original_edge_ids.items():
136+
if (
137+
type_name not in self.edge_feature_keys
138+
or edges is None
139+
):
140140
continue
141141
if edges.is_cuda:
142142
edges.record_stream(torch.cuda.current_stream())
143-
for feature_name in feature_names:
143+
for feature_name in self.edge_feature_keys[type_name]:
144144
edge_features[i][
145145
(type_name, feature_name)
146146
] = record_stream(

python/dgl/graphbolt/impl/gpu_cached_feature.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,22 @@
88
__all__ = ["GPUCachedFeature"]
99

1010

11+
def nbytes(tensor):
12+
"""Returns the number of bytes to store the given tensor.
13+
14+
Needs to be defined only for torch versions less than 2.1. In torch >= 2.1,
15+
we can simply use "tensor.nbytes".
16+
"""
17+
return tensor.numel() * tensor.element_size()
18+
19+
20+
def num_cache_items(cache_capacity_in_bytes, single_item):
21+
"""Returns the number of rows to be cached."""
22+
item_bytes = nbytes(single_item)
23+
# Round up so that we never get a size of 0, unless bytes is 0.
24+
return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes
25+
26+
1127
class GPUCachedFeature(Feature):
1228
r"""GPU cached feature wrapping a fallback feature.
1329
@@ -17,8 +33,8 @@ class GPUCachedFeature(Feature):
1733
----------
1834
fallback_feature : Feature
1935
The fallback feature.
20-
cache_size : int
21-
The capacity of the GPU cache, the number of features to store.
36+
max_cache_size_in_bytes : int
37+
The capacity of the GPU cache in bytes.
2238
2339
Examples
2440
--------
@@ -42,16 +58,17 @@ class GPUCachedFeature(Feature):
4258
torch.Size([5])
4359
"""
4460

45-
def __init__(self, fallback_feature: Feature, cache_size: int):
61+
def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int):
4662
super(GPUCachedFeature, self).__init__()
4763
assert isinstance(fallback_feature, Feature), (
4864
f"The fallback_feature must be an instance of Feature, but got "
4965
f"{type(fallback_feature)}."
5066
)
5167
self._fallback_feature = fallback_feature
52-
self.cache_size = cache_size
68+
self.max_cache_size_in_bytes = max_cache_size_in_bytes
5369
# Fetching the feature dimension from the underlying feature.
5470
feat0 = fallback_feature.read(torch.tensor([0]))
71+
cache_size = num_cache_items(max_cache_size_in_bytes, feat0)
5572
self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype)
5673

5774
def read(self, ids: torch.Tensor = None):
@@ -104,11 +121,15 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
104121
updated.
105122
"""
106123
if ids is None:
124+
feat0 = value[:1]
107125
self._fallback_feature.update(value)
108-
size = min(self.cache_size, value.shape[0])
109-
self._feature.replace(
110-
torch.arange(0, size, device="cuda"),
111-
value[:size].to("cuda"),
126+
cache_size = min(
127+
num_cache_items(self.max_cache_size_in_bytes, feat0),
128+
value.shape[0],
129+
)
130+
self._feature = None # Destroy the existing cache first.
131+
self._feature = GPUCache(
132+
(cache_size,) + feat0.shape[1:], feat0.dtype
112133
)
113134
else:
114135
self._fallback_feature.update(value, ids)

script/dgl_dev.yml.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,6 @@ dependencies:
4949
- lintrunner
5050
- jupyterlab
5151
- ipywidgets
52+
- expecttest
5253
variables:
5354
DGL_HOME: __DGL_HOME__

tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,10 +1613,14 @@ def test_csc_sampling_graph_to_pinned_memory():
16131613
is_graph_pinned(graph)
16141614

16151615

1616+
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
1617+
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
16161618
@pytest.mark.parametrize("labor", [False, True])
16171619
@pytest.mark.parametrize("is_pinned", [False, True])
16181620
@pytest.mark.parametrize("nodes", [None, True])
1619-
def test_sample_neighbors_homo(labor, is_pinned, nodes):
1621+
def test_sample_neighbors_homo(
1622+
indptr_dtype, indices_dtype, labor, is_pinned, nodes
1623+
):
16201624
if is_pinned and nodes is None:
16211625
pytest.skip("Optional nodes and is_pinned is not supported together.")
16221626
"""Original graph in COO:
@@ -1630,8 +1634,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
16301634
pytest.skip("Pinning is not meaningful without a GPU.")
16311635
# Initialize data.
16321636
total_num_edges = 12
1633-
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
1634-
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
1637+
indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
1638+
indices = torch.tensor(
1639+
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
1640+
)
16351641
assert indptr[-1] == total_num_edges
16361642
assert indptr[-1] == len(indices)
16371643

@@ -1642,7 +1648,7 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
16421648

16431649
# Generate subgraph via sample neighbors.
16441650
if nodes:
1645-
nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
1651+
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype).to(F.ctx())
16461652
elif F._default_context_str != "gpu":
16471653
pytest.skip("Optional nodes is supported only for the GPU.")
16481654
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
@@ -1662,8 +1668,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
16621668
assert subgraph.original_edge_ids is None
16631669

16641670

1671+
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
1672+
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
16651673
@pytest.mark.parametrize("labor", [False, True])
1666-
def test_sample_neighbors_hetero(labor):
1674+
def test_sample_neighbors_hetero(indptr_dtype, indices_dtype, labor):
16671675
"""Original graph in COO:
16681676
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
16691677
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
@@ -1677,10 +1685,12 @@ def test_sample_neighbors_hetero(labor):
16771685
ntypes = {"n1": 0, "n2": 1}
16781686
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
16791687
total_num_edges = 9
1680-
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
1681-
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
1682-
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
1683-
node_type_offset = torch.LongTensor([0, 2, 5])
1688+
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
1689+
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
1690+
type_per_edge = torch.tensor(
1691+
[1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=indices_dtype
1692+
)
1693+
node_type_offset = torch.tensor([0, 2, 5], dtype=indices_dtype)
16841694
assert indptr[-1] == total_num_edges
16851695
assert indptr[-1] == len(indices)
16861696

@@ -1696,8 +1706,8 @@ def test_sample_neighbors_hetero(labor):
16961706

16971707
# Sample on both node types.
16981708
nodes = {
1699-
"n1": torch.tensor([0], device=F.ctx()),
1700-
"n2": torch.tensor([0], device=F.ctx()),
1709+
"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
1710+
"n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
17011711
}
17021712
fanouts = torch.tensor([-1, -1])
17031713
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
@@ -1725,7 +1735,7 @@ def test_sample_neighbors_hetero(labor):
17251735
assert subgraph.original_edge_ids is None
17261736

17271737
# Sample on single node type.
1728-
nodes = {"n1": torch.tensor([0], device=F.ctx())}
1738+
nodes = {"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx())}
17291739
fanouts = torch.tensor([-1, -1])
17301740
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
17311741
subgraph = sampler(nodes, fanouts)

tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
3636
[[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True
3737
)
3838

39+
cache_size_a *= a[:1].element_size() * a[:1].numel()
40+
cache_size_b *= b[:1].element_size() * b[:1].numel()
41+
3942
feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a)
4043
feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b)
4144

@@ -94,3 +97,7 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
9497
feat_store_a.read(),
9598
torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"),
9699
)
100+
101+
# Test with different dimensionality
102+
feat_store_a.update(b)
103+
assert torch.equal(feat_store_a.read(), b.to("cuda"))

tests/python/pytorch/graphbolt/test_feature_fetcher.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,21 @@ def test_FeatureFetcher_hetero():
160160
num_layer = 2
161161
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
162162
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
163+
# "n3" is not in the sampled input nodes.
164+
node_feature_keys = {"n1": ["a"], "n2": ["a"], "n3": ["a"]}
163165
fetcher_dp = gb.FeatureFetcher(
164-
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
166+
sampler_dp, feature_store, node_feature_keys=node_feature_keys
165167
)
166-
167168
assert len(list(fetcher_dp)) == 3
168169

170+
# Do not fetch feature for "n1".
171+
node_feature_keys = {"n2": ["a"]}
172+
fetcher_dp = gb.FeatureFetcher(
173+
sampler_dp, feature_store, node_feature_keys=node_feature_keys
174+
)
175+
for mini_batch in fetcher_dp:
176+
assert ("n1", "a") not in mini_batch.node_features
177+
169178

170179
def test_FeatureFetcher_with_edges_hetero():
171180
a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
@@ -208,7 +217,11 @@ def add_node_and_edge_ids(minibatch):
208217
return data
209218

210219
features = {}
211-
keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")]
220+
keys = [
221+
("node", "n1", "a"),
222+
("edge", "n1:e1:n2", "a"),
223+
("edge", "n2:e2:n1", "a"),
224+
]
212225
features[keys[0]] = gb.TorchBasedFeature(a)
213226
features[keys[1]] = gb.TorchBasedFeature(b)
214227
feature_store = gb.BasicFeatureStore(features)
@@ -220,8 +233,15 @@ def add_node_and_edge_ids(minibatch):
220233
)
221234
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
222235
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
236+
# "n3:e3:n3" is not in the sampled edges.
237+
# Do not fetch feature for "n2:e2:n1".
238+
node_feature_keys = {"n1": ["a"]}
239+
edge_feature_keys = {"n1:e1:n2": ["a"], "n3:e3:n3": ["a"]}
223240
fetcher_dp = gb.FeatureFetcher(
224-
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
241+
converter_dp,
242+
feature_store,
243+
node_feature_keys=node_feature_keys,
244+
edge_feature_keys=edge_feature_keys,
225245
)
226246

227247
assert len(list(fetcher_dp)) == 5
@@ -230,3 +250,4 @@ def add_node_and_edge_ids(minibatch):
230250
assert len(data.edge_features) == 3
231251
for edge_feature in data.edge_features:
232252
assert edge_feature[("n1:e1:n2", "a")].size(0) == 10
253+
assert ("n2:e2:n1", "a") not in edge_feature

third_party/cccl

Submodule cccl updated 9797 files

0 commit comments

Comments
 (0)