Skip to content

Commit f60a200

Browse files
committed
revert changes from other branch
1 parent 870b10f commit f60a200

File tree

2 files changed

+3
-17
lines changed

2 files changed

+3
-17
lines changed

python/cugraph-pyg/cugraph_pyg/data/feature_store.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ def __make_wg_tensor(self, tensor, ix=None):
8888
(torch.float64, 1),
8989
(torch.int32, 2),
9090
(torch.int64, 3),
91-
(torch.int16, 4),
92-
(torch.float16, 5),
93-
(torch.int8, 6),
91+
(torch.bool, 4),
9492
]:
9593
dtypes[k] = v
9694
dtype_ids[v] = k

python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,9 @@ def test_feature_store_basic_api(single_pytorch_worker):
5252
)
5353
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
5454
@pytest.mark.sg
55-
@pytest.mark.parametrize(
56-
"dtype",
57-
[
58-
torch.float32,
59-
torch.float16,
60-
torch.int8,
61-
torch.int16,
62-
torch.int32,
63-
torch.int64,
64-
torch.float64,
65-
],
66-
)
67-
def test_feature_store_basic_api_types(single_pytorch_worker, dtype):
55+
def test_feature_store_basic_api_float(single_pytorch_worker):
6856
features = torch.arange(0, 2000)
69-
features = features.reshape((features.numel() // 100, 100)).to(dtype)
57+
features = features.reshape((features.numel() // 100, 100)).to(torch.float32)
7058

7159
whole_store = FeatureStore()
7260
whole_store["node", "fea", None] = features

0 commit comments

Comments
 (0)