Skip to content

Commit 3e856cb

Browse files
committed
[Test] Fix indexing benchmark
ghstack-source-id: 6c87867 Pull-Request: #1468
1 parent 5be3736 commit 3e856cb

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

benchmarks/compile/compile_td_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,10 @@ def test_compile_assign_and_add_stack(mode, benchmark):
309309
def test_compile_indexing(mode, dict_type, index_type, benchmark):
310310
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
311311
td = TensorDict(
312-
{"a": torch.arange(100), "b": {"c": torch.arange(100)}},
312+
{
313+
"a": torch.arange(100, device=device),
314+
"b": {"c": torch.arange(100, device=device)},
315+
},
313316
batch_size=[100],
314317
device=device,
315318
)
@@ -329,7 +332,7 @@ def test_compile_indexing(mode, dict_type, index_type, benchmark):
329332
else:
330333
idx = slice(None, None, 2)
331334
if index_type == "tensor":
332-
idx = torch.tensor(range(*idx.indices(10)))
335+
idx = torch.tensor(range(*idx.indices(10)), device=device)
333336

334337
func(td, idx)
335338
func(td, idx)

0 commit comments

Comments
 (0)