@@ -575,8 +575,13 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
+
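+        # Gather the tensor indexers up front: under advanced indexing they
+        # broadcast together, so their combined output shape (None when no
+        # combined broadcasting applies) must be known before walking the
+        # index elements one at a time.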
+        tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(tensors)
+        first_broadcast_tensor_idx: int | None = None
+
         k_index = 0
-        for k in index:
+        for position, k in enumerate(index):
             if k is None:
                 output_size.append(1)
             elif isinstance(k, int):
@@ -617,11 +622,13 @@ def compute_shape(
                 else:
                     output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
-                input_size.popleft()
-                output_size.extend(k.size())
+            elif isinstance(k, torch.Tensor):
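+                # Each tensor indexer consumes one input dim, but when a
+                # combined broadcast shape exists it is emitted into
+                # output_size only once, at the first tensor indexer's
+                # position; otherwise each tensor contributes its own dims.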
+                base_dim = input_size.popleft()
+                if broadcast_shape is None:
+                    output_size.extend(env.tensor_indexer_dims(k, base_dim))
+                elif first_broadcast_tensor_idx is None:
+                    output_size.extend(broadcast_shape)
+                    first_broadcast_tensor_idx = position
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(k)
@@ -667,13 +674,115 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
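+        # Recompute the same broadcast view of the tensor indexers that
+        # compute_shape used, so codegen agrees on where those dims land.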
+        all_tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(all_tensors)
+        tensor_shapes = [list(t.size()) for t in all_tensors]
+        first_tensor_idx = 0
+        tensor_count = 0
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
 
         k_index = 0
+
+        def tensor_index_source_and_mask(
+            index_elem: torch.Tensor, index_var: str, pos: int
+        ) -> tuple[str, int | None]:
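+            """Return the index-source expression and the block id to mask with."""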
+            tile_id = env.get_tile_index_tensor_block_id(index_elem)
+            src = state.codegen.index_var(tile_id) if tile_id else index_var
+            mask_id = tile_id or (
+                env.get_block_id(output_size[pos]) if pos < len(output_size) else None
+            )
+            return src, mask_id
+
+        def handle_broadcast_tensor(
+            position: int, index_elem: torch.Tensor, index_var: str
+        ) -> None:
+            """Handle tensor index with broadcast shape (cartesian or general)."""
+            nonlocal first_tensor_idx, output_idx, tensor_count, k_index
+            assert broadcast_shape is not None
+            dims = len(broadcast_shape)
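+            # The broadcast dims appear in the output exactly once, so only
+            # the first tensor indexer claims positions and advances output_idx.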
+            if tensor_count == 0:
+                first_tensor_idx = output_idx
+                output_idx += dims
+
+            shape = (
+                tensor_shapes[tensor_count]
+                if tensor_count < len(tensor_shapes)
+                else [1]
+            )
+            # Cartesian: multiple 1D tensors each contributing one dim
+            is_cartesian = (
+                dims >= 2
+                and len(tensor_shapes) == dims
+                and all(
+                    len(s) == 1 or sum(1 for d in s if env.size_hint(d) != 1) <= 1
+                    for s in tensor_shapes
+                )
+            )
+            # Find position(s) where this tensor contributes
+            offset = max(0, dims - len(shape))
+            contrib = [
+                first_tensor_idx + offset + i
+                for i, d in enumerate(shape)
+                if env.size_hint(d) != 1
+            ]
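+            # Choose the single output position this tensor's expression is
+            # expanded at: its own slot for a cartesian product, else its
+            # first contributing dim, else a clamped fallback position.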
+            pos = (
+                first_tensor_idx + tensor_count
+                if is_cartesian
+                else (
+                    contrib[0]
+                    if contrib
+                    else max(
+                        0,
+                        min(
+                            first_tensor_idx + offset + len(shape) - 1,
+                            len(output_size) - 1,
+                        ),
+                    )
+                )
+            )
+            # Generate index expression
+            if is_cartesian or len(contrib) <= 1:
+                src, mask_id = tensor_index_source_and_mask(index_elem, index_var, pos)
+                expand = (
+                    tile_strategy.expand_str(output_size, pos)
+                    if index_elem.ndim == 1
+                    else ""
+                )
+                index_values.append(f"({src}){expand}")
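+                # Only the first tensor indexer emits masks; the others cover
+                # the same broadcast dims and would just duplicate them.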
+                if (
+                    tensor_count == 0
+                    and mask_id
+                    and (mv := state.codegen.mask_var(mask_id))
+                ):
+                    if not _is_size_one(fake_value.size(len(index_values) - 1)):
+                        mask_values.setdefault(f"({mv}){expand}")
+            else:
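+                # General broadcast: the tensor spans several output dims, so
+                # emit it unexpanded and mask every dim it contributes to.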
+                index_values.append(f"({index_var})")
+                if tensor_count == 0:
+                    for p in contrib:
+                        if p < len(output_size) and (
+                            bid := env.get_block_id(output_size[p])
+                        ):
+                            if (mv := state.codegen.mask_var(bid)) and not _is_size_one(
+                                fake_value.size(len(index_values) - 1)
+                            ):
+                                mask_values.setdefault(
+                                    f"({mv}){tile_strategy.expand_str(output_size, p)}"
+                                )
+            # Padded iota mask
+            if (
+                orig_len := _get_padded_iota_original_length(state, position)
+            ) is not None:
+                mask_values.setdefault(
+                    f"(({index_var} < {orig_len}){tile_strategy.expand_str(output_size, first_tensor_idx + tensor_count)})"
+                )
+            tensor_count += 1
+            k_index += 1
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -752,40 +861,35 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
-                # Check if this index comes from a padded hl.arange and generate mask
-                if (
-                    original_length := _get_padded_iota_original_length(state, n)
-                ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
-                k_index += 1
-            elif (
-                isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
-            ):
-                # TODO(jansel): combine this case with the above
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == 1
-                index_var = state.codegen.lift(ast_index[0], prefix="index").id
-                index_values.append(index_var)
-                output_idx += k.ndim
-                for n, s in enumerate(output_size):
-                    if (block_idx := env.get_block_id(s)) is not None and (
-                        mask := state.codegen.mask_var(block_idx)
+
+                # Use broadcast handling for: multiple tensors, or a single tensor with ndim > 1
+                if broadcast_shape is not None and (len(all_tensors) > 1 or k.ndim > 1):
+                    handle_broadcast_tensor(n, k, index_var)
+                    continue
+
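+                # Simple path: no combined broadcast handling needed; emit the
+                # tensor's index expression at the current output position.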
+                index_source, mask_block_id = tensor_index_source_and_mask(
+                    k, index_var, output_idx
+                )
+
+                expand = (
+                    tile_strategy.expand_str(output_size, output_idx)
+                    if k.ndim < len(output_size)
+                    else ""
+                )
+                index_values.append(f"({index_source}){expand}")
+                if mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
                     ):
-                        mask_values.setdefault(
-                            f"({mask}){tile_strategy.expand_str(output_size, n)}"
-                        )
+                        mask_values.setdefault(f"({mask_var}){expand}")
+
+                output_idx += k.ndim
+                tensor_count += 1
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))
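
For reference, the shape rules the new compute_shape branch models are the standard PyTorch advanced-indexing semantics: all tensor indexers broadcast to one combined shape that appears exactly once in the output. A minimal standalone sketch of those semantics (the tensors below are illustrative, not taken from the diff):

    import torch

    t = torch.arange(24).reshape(4, 6)

    # Two broadcastable tensor indexers: shapes (2, 1) and (1, 3) broadcast
    # to (2, 3), which appears once in the result.
    rows = torch.tensor([[0], [2]])
    cols = torch.tensor([[1, 3, 5]])
    assert t[rows, cols].shape == (2, 3)

    # A single 2-D tensor indexer contributes all of its dims to the output,
    # followed by the remaining un-indexed input dim.
    idx = torch.tensor([[0, 1], [2, 3]])
    assert t[idx].shape == (2, 2, 6)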