66import torch
77from torch import nn
88
9+ from .cugraph_base import CuGraphBaseConv
10+
911try :
10- from pylibcugraphops import make_fg_csr_hg , make_mfg_csr_hg
11- from pylibcugraphops .torch . autograd import (
12+ from pylibcugraphops . pytorch import SampledHeteroCSC , StaticHeteroCSC
13+ from pylibcugraphops .pytorch . operators import (
1214 agg_hg_basis_n2n_post as RelGraphConvAgg ,
1315 )
16+
17+ HAS_PYLIBCUGRAPHOPS = True
1418except ImportError :
15- has_pylibcugraphops = False
16- else :
17- has_pylibcugraphops = True
19+ HAS_PYLIBCUGRAPHOPS = False
1820
1921
20- class CuGraphRelGraphConv (nn . Module ):
22+ class CuGraphRelGraphConv (CuGraphBaseConv ):
2123 r"""An accelerated relational graph convolution layer from `Modeling
2224 Relational Data with Graph Convolutional Networks
2325 <https://arxiv.org/abs/1703.06103>`__ that leverages the highly-optimized
@@ -26,7 +28,8 @@ class CuGraphRelGraphConv(nn.Module):
2628 See :class:`dgl.nn.pytorch.conv.RelGraphConv` for mathematical model.
2729
2830 This module depends on :code:`pylibcugraphops` package, which can be
29- installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
31+ installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
32+ :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.
3033
3134 .. note::
3235 This is an **experimental** feature.
@@ -92,10 +95,11 @@ def __init__(
9295 dropout = 0.0 ,
9396 apply_norm = False ,
9497 ):
95- if has_pylibcugraphops is False :
98+ if HAS_PYLIBCUGRAPHOPS is False :
9699 raise ModuleNotFoundError (
97- f"{ self .__class__ .__name__ } requires pylibcugraphops >= 23.02 "
98- f"to be installed."
100+ f"{ self .__class__ .__name__ } requires pylibcugraphops=23.04. "
101+ f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`."
102+ f"pylibcugraphops requires Python 3.8 or 3.10."
99103 )
100104 super ().__init__ ()
101105 self .in_feat = in_feat
@@ -176,53 +180,36 @@ def forward(self, g, feat, etypes, max_in_degree=None):
176180 torch.Tensor
177181 New node features. Shape: :math:`(|V|, D_{out})`.
178182 """
179- # Create csc-representation and cast etypes to int32.
180183 offsets , indices , edge_ids = g .adj_tensors ("csc" )
181184 edge_types_perm = etypes [edge_ids .long ()].int ()
182185
183- # Create cugraph-ops graph.
184186 if g .is_block :
185187 if max_in_degree is None :
186188 max_in_degree = g .in_degrees ().max ().item ()
187189
188190 if max_in_degree < self .MAX_IN_DEGREE_MFG :
189- _graph = make_mfg_csr_hg (
190- g .dstnodes (),
191+ _graph = SampledHeteroCSC (
191192 offsets ,
192193 indices ,
194+ edge_types_perm ,
193195 max_in_degree ,
194196 g .num_src_nodes (),
195- n_node_types = 0 ,
196- n_edge_types = self .num_rels ,
197- out_node_types = None ,
198- in_node_types = None ,
199- edge_types = edge_types_perm ,
197+ self .num_rels ,
200198 )
201199 else :
202- offsets_fg = torch .empty (
203- g .num_src_nodes () + 1 ,
204- dtype = offsets .dtype ,
205- device = offsets .device ,
206- )
207- offsets_fg [: offsets .numel ()] = offsets
208- offsets_fg [offsets .numel () :] = offsets [- 1 ]
209-
210- _graph = make_fg_csr_hg (
200+ offsets_fg = self .pad_offsets (offsets , g .num_src_nodes () + 1 )
201+ _graph = StaticHeteroCSC (
211202 offsets_fg ,
212203 indices ,
213- n_node_types = 0 ,
214- n_edge_types = self .num_rels ,
215- node_types = None ,
216- edge_types = edge_types_perm ,
204+ edge_types_perm ,
205+ self .num_rels ,
217206 )
218207 else :
219- _graph = make_fg_csr_hg (
208+ _graph = StaticHeteroCSC (
220209 offsets ,
221210 indices ,
222- n_node_types = 0 ,
223- n_edge_types = self .num_rels ,
224- node_types = None ,
225- edge_types = edge_types_perm ,
211+ edge_types_perm ,
212+ self .num_rels ,
226213 )
227214
228215 h = RelGraphConvAgg (
0 commit comments