@@ -456,6 +456,7 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
456456 inputs_consumed = 0
457457 output_sizes = []
458458 env = CompileEnvironment .current ()
459+
459460 for k in keys :
460461 if isinstance (k , LiteralType ):
461462 if isinstance (k .value , (int , torch .SymInt )):
@@ -501,19 +502,33 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
501502 raise exc .DataDependentOutputShapeNotSupported (
502503 op_desc = "Boolean mask indexing (tensor[boolean_mask])"
503504 )
504- elif isinstance (k , TensorType ) and k .fake_value .ndim == 1 :
505+ elif isinstance (k , TensorType ):
506+ # Handle tensor indexing (both 1D and multi-dimensional)
507+ # For advanced indexing, multiple tensor indices are broadcast together
508+ # and the first one determines the output dimensions
505509 inputs_consumed += 1
506- output_sizes .append (k .fake_value .size (0 ))
510+ # Add all dimensions of the tensor for multi-dimensional indexing
511+ for dim in range (k .fake_value .ndim ):
512+ output_sizes .append (k .fake_value .size (dim ))
507513 elif k .contains_type (TileIndexType ):
508514 raise exc .OverpackedTile (k )
509515 else :
510516 raise exc .InvalidIndexingType (k )
511- if inputs_consumed != self .fake_value .ndim :
517+ # Advanced indexing with tensors can consume fewer dimensions than the tensor has
518+ # Only check for consuming too many dimensions
519+ if inputs_consumed > self .fake_value .ndim :
512520 raise exc .RankMismatch (
513521 self .fake_value .ndim ,
514522 inputs_consumed ,
515- f"tensor shape: { tuple (self .fake_value .shape )} " ,
523+ f"tensor shape: { tuple (self .fake_value .shape )} , consumed { inputs_consumed } dimensions " ,
516524 )
525+
526+ # Add any remaining dimensions from the original tensor
527+ # This handles cases like tensor[idx] where tensor is multi-dimensional
528+ # and idx is a tensor that only indexes the first dimension
529+ for dim in range (inputs_consumed , self .fake_value .ndim ):
530+ output_sizes .append (self .fake_value .size (dim ))
531+
517532 return output_sizes
518533
519534 def propagate_setitem (
0 commit comments