@@ -575,8 +575,15 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
+
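+        # When several tensor indices are present they broadcast together
+        # into one shared output shape (NumPy-style advanced indexing).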
+        tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(tensors)
+        first_broadcast_tensor_idx: int | None = None
+
         k_index = 0
-        for k in index:
+        for position, k in enumerate(index):
             if k is None:
                 output_size.append(1)
             elif isinstance(k, int):
@@ -617,11 +624,15 @@ def compute_shape(
                 else:
                     output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
-                input_size.popleft()
-                output_size.extend(k.size())
+            elif isinstance(k, torch.Tensor):
+                base_dim = input_size.popleft()
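+                # The broadcast shape is emitted only once, at the position
+                # of the first tensor index; later tensor indices add nothing.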
+                if broadcast_shape is None:
+                    output_size.extend(env.tensor_indexer_dims(k, base_dim))
+                elif first_broadcast_tensor_idx is None:
+                    output_size.extend(broadcast_shape)
+                    first_broadcast_tensor_idx = position
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(k)
@@ -667,13 +678,163 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        all_tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(all_tensors)
+        tensor_shapes = [list(t.size()) for t in all_tensors]
+        first_tensor_idx = 0
+        tensor_count = 0
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
 
         k_index = 0
+
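+        # A tile-index tensor reuses the tile's own index variable and mask;
+        # any other tensor uses the lifted index expression and the mask of
+        # the output dimension it lands in.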
+        def tensor_index_source_and_mask(
+            index_elem: torch.Tensor,
+            index_var: str,
+            pos: int,
+        ) -> tuple[str, int | None]:
+            tile_block_id = env.get_tile_index_tensor_block_id(index_elem)
+            index_source = (
+                state.codegen.index_var(tile_block_id)
+                if tile_block_id is not None
+                else index_var
+            )
+            mask_block_id = tile_block_id or (
+                env.get_block_id(output_size[pos]) if pos < len(output_size) else None
+            )
+            return index_source, mask_block_id
+
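+        # Cartesian case: N tensor indices, each with at most one non-size-1
+        # dim, span an N-D broadcast shape; each expands along its own axis.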
+        def handle_cartesian_tensor(
+            position: int, index_elem: torch.Tensor, index_var: str
+        ) -> bool:
+            nonlocal first_tensor_idx, output_idx, tensor_count, k_index
+
+            assert broadcast_shape is not None
+            dims = len(broadcast_shape)
+            is_cartesian = (
+                dims >= 2
+                and len(tensor_shapes) == dims
+                and all(
+                    len(shape) == 1
+                    or sum(1 for dim in shape if env.size_hint(dim) != 1) <= 1
+                    for shape in tensor_shapes
+                )
+            )
+            if not is_cartesian:
+                return False
+
+            original_length = _get_padded_iota_original_length(state, position)
+            if tensor_count == 0:
+                first_tensor_idx = output_idx
+                output_idx += dims
+
+            axis_pos = first_tensor_idx + tensor_count
+            expand_axis = tile_strategy.expand_str(output_size, axis_pos)
+            index_values.append(f"({index_var}){expand_axis}")
+
+            if original_length is not None:
+                mask_values.setdefault(
+                    f"(({index_var} < {original_length}){expand_axis})"
+                )
+
+            tensor_count += 1
+            k_index += 1
+            return True
+
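+        # General broadcast case: the first tensor index reserves all of the
+        # broadcast dims in the output and emits the masks covering them.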
+        def handle_broadcast_tensor(index_elem: torch.Tensor, index_var: str) -> bool:
+            nonlocal first_tensor_idx, output_idx, tensor_count, k_index
+
+            assert broadcast_shape is not None
+            if tensor_count == 0:
+                first_tensor_idx = output_idx
+                output_idx += len(broadcast_shape)
+
+            shape = (
+                tensor_shapes[tensor_count]
+                if tensor_count < len(tensor_shapes)
+                else [1]
+            )
+            shape_size = len(shape)
+            non_bcast_dims = sum(1 for dim in shape if env.size_hint(dim) != 1)
+            is_single_dim = non_bcast_dims <= 1
+
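+            # Right-align this tensor's shape against the broadcast shape to
+            # find the output axis its non-size-1 dimension expands along.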
+            offset = max(0, len(broadcast_shape) - shape_size)
+            non_one_positions = [
+                i for i, dim in enumerate(shape) if env.size_hint(dim) != 1
+            ]
+            expand_pos = first_tensor_idx + offset
+            if is_single_dim and shape_size > 0:
+                rel_pos = non_one_positions[0] if non_one_positions else shape_size - 1
+                expand_pos += rel_pos
+
+            if output_size:
+                expand_pos = max(0, min(expand_pos, len(output_size) - 1))
+            else:
+                expand_pos = 0
+
+            # Number of dimensions to process for tensor indexing expansion
+            width = (
+                1
+                if is_single_dim
+                else min(shape_size, max(0, len(output_size) - expand_pos))
+            )
+
+            if width <= 1:
+                expand_pos_str = (
+                    tile_strategy.expand_str(output_size, expand_pos)
+                    if index_elem.ndim == 1
+                    else ""
+                )
+                index_source, mask_block_id = tensor_index_source_and_mask(
+                    index_elem,
+                    index_var,
+                    expand_pos,
+                )
+                expand = (
+                    expand_pos_str
+                    if expand_pos_str is not None
+                    else tile_strategy.expand_str(output_size, expand_pos)
+                )
+                index_values.append(f"({index_source}){expand}")
+                if tensor_count == 0 and mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
+                    ):
+                        mask_values.setdefault(f"({mask_var}){expand}")
+            else:
+                index_values.append(f"({index_var})")
+
+            if tensor_count == 0:
+                for pos in [expand_pos + d for d in range(width)]:
+                    if pos >= len(output_size):
+                        continue
+                    block_idx = env.get_block_id(output_size[pos])
+                    if block_idx is None:
+                        continue
+                    expand_str = tile_strategy.expand_str(output_size, pos)
+                    mask_var = state.codegen.mask_var(block_idx)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
+                    ):
+                        mask_values.setdefault(f"({mask_var}){expand_str}")
+
+            tensor_count += 1
+            k_index += 1
+            return True
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -752,40 +913,39 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
-                # Check if this index comes from a padded hl.arange and generate mask
-                if (
-                    original_length := _get_padded_iota_original_length(state, n)
-                ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
-                k_index += 1
-            elif (
-                isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
-            ):
-                # TODO(jansel): combine this case with the above
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == 1
-                index_var = state.codegen.lift(ast_index[0], prefix="index").id
-                index_values.append(index_var)
-                output_idx += k.ndim
-                for n, s in enumerate(output_size):
-                    if (block_idx := env.get_block_id(s)) is not None and (
-                        mask := state.codegen.mask_var(block_idx)
+
+                if broadcast_shape is not None and (
+                    handle_cartesian_tensor(n, k, index_var)
+                    or handle_broadcast_tensor(k, index_var)
+                ):
+                    continue
+
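+                # Fallback (broadcast_shape is None): the lone tensor index
+                # contributes its own dimensions directly.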
+                index_source, mask_block_id = tensor_index_source_and_mask(
+                    k, index_var, output_idx
+                )
+
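+                # Expand only when the tensor has fewer dims than the output.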
+                expand = (
+                    tile_strategy.expand_str(output_size, output_idx)
+                    if k.ndim < len(output_size)
+                    else ""
+                )
+                index_values.append(f"({index_source}){expand}")
+                if mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
                     ):
-                        mask_values.setdefault(
-                            f"({mask}){tile_strategy.expand_str(output_size, n)}"
-                        )
+                        mask_values.setdefault(f"({mask_var}){expand}")
+
+                output_idx += k.ndim
+                tensor_count += 1
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))