Skip to content

Commit 37daf60

Browse files
committed
wip
1 parent 830fbfb commit 37daf60

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,15 +614,17 @@ def compute_shape(
614614
else:
615615
output_size.append(1)
616616
k_index += 1
617-
elif isinstance(k, torch.Tensor) and (
618-
k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
619-
):
617+
elif isinstance(k, torch.Tensor):
618+
# Handle tensor indexing (both 1D and multi-dimensional)
620619
input_size.popleft()
620+
# Add all dimensions of the indexing tensor to output
621621
output_size.extend(k.size())
622622
k_index += 1
623623
else:
624624
raise exc.InvalidIndexingType(k)
625-
assert len(input_size) == 0, "invalid subscript"
625+
# Advanced indexing might not consume all dimensions
626+
# Add any remaining dimensions from the input
627+
output_size.extend(input_size)
626628
return output_size
627629

628630
@staticmethod

helion/_compiler/type_propagation.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
456456
inputs_consumed = 0
457457
output_sizes = []
458458
env = CompileEnvironment.current()
459+
459460
for k in keys:
460461
if isinstance(k, LiteralType):
461462
if isinstance(k.value, (int, torch.SymInt)):
@@ -501,19 +502,33 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
501502
raise exc.DataDependentOutputShapeNotSupported(
502503
op_desc="Boolean mask indexing (tensor[boolean_mask])"
503504
)
504-
elif isinstance(k, TensorType) and k.fake_value.ndim == 1:
505+
elif isinstance(k, TensorType):
506+
# Handle tensor indexing (both 1D and multi-dimensional)
507+
# For advanced indexing, multiple tensor indices are broadcast together
508+
# and the first one determines the output dimensions
505509
inputs_consumed += 1
506-
output_sizes.append(k.fake_value.size(0))
510+
# Add all dimensions of the tensor for multi-dimensional indexing
511+
for dim in range(k.fake_value.ndim):
512+
output_sizes.append(k.fake_value.size(dim))
507513
elif k.contains_type(TileIndexType):
508514
raise exc.OverpackedTile(k)
509515
else:
510516
raise exc.InvalidIndexingType(k)
511-
if inputs_consumed != self.fake_value.ndim:
517+
# Advanced indexing with tensors can consume fewer dimensions than the tensor has
518+
# Only check for consuming too many dimensions
519+
if inputs_consumed > self.fake_value.ndim:
512520
raise exc.RankMismatch(
513521
self.fake_value.ndim,
514522
inputs_consumed,
515-
f"tensor shape: {tuple(self.fake_value.shape)}",
523+
f"tensor shape: {tuple(self.fake_value.shape)}, consumed {inputs_consumed} dimensions",
516524
)
525+
526+
# Add any remaining dimensions from the original tensor
527+
# This handles cases like tensor[idx] where tensor is multi-dimensional
528+
# and idx is a tensor that only indexes the first dimension
529+
for dim in range(inputs_consumed, self.fake_value.ndim):
530+
output_sizes.append(self.fake_value.size(dim))
531+
517532
return output_sizes
518533

519534
def propagate_setitem(

0 commit comments

Comments (0)