@@ -13,8 +13,7 @@ namespace graphbolt {
1313namespace ops {
1414
1515torch::Tensor IndexSelect (torch::Tensor input, torch::Tensor index) {
16- if (input.is_pinned () &&
17- (index.is_pinned () || index.device ().type () == c10::DeviceType::CUDA)) {
16+ if (utils::is_on_gpu (index) && input.is_pinned ()) {
1817 GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE (
1918 c10::DeviceType::CUDA, " UVAIndexSelect" ,
2019 { return UVAIndexSelectImpl (input, index); });
@@ -26,9 +25,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
2625 torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
2726 TORCH_CHECK (
2827 indices.sizes ().size () == 1 , " IndexSelectCSC only supports 1d tensors" );
29- if (utils::is_accessible_from_gpu (indptr) &&
30- utils::is_accessible_from_gpu (indices) &&
31- utils::is_accessible_from_gpu (nodes)) {
28+ if (utils::is_on_gpu (nodes) && utils::is_accessible_from_gpu (indptr) &&
29+ utils::is_accessible_from_gpu (indices)) {
3230 GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE (
3331 c10::DeviceType::CUDA, " IndexSelectCSCImpl" ,
3432 { return IndexSelectCSCImpl (indptr, indices, nodes); });
0 commit comments