Commit 9029b5a

fix
1 parent e695ab4 commit 9029b5a

File tree

6 files changed: +387 -46 lines changed

helion/_compiler/compile_environment.py

Lines changed: 100 additions & 0 deletions
@@ -272,6 +272,106 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor] | None
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed.
+
+        Returns:
+            - None: when there are no tensor indexers, or all indexers already
+              carry ``_tile_index_block_id`` (tile-origin indices should not
+              participate in broadcast/cartesian expansion).
+            - list[int | SymInt]: the broadcast shape to apply when mixing
+              multiple non-tile tensor indexers.
+        """
+        tensor_list = [t for t in tensors or [] if isinstance(t, torch.Tensor)]
+        if not tensor_list or all(
+            self.get_tile_index_tensor_block_id(t) for t in tensor_list
+        ):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+
+        # Special case: multiple 1D tensors form a Cartesian product
+        if all(len(s) == 1 for s in shapes) and len(shapes) > 1:
+            return [s[0] for s in shapes]
+
+        # General broadcasting case
+        max_ndim = max(len(s) for s in shapes)
+        padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+        return [
+            next((d for d in dims if self.size_hint(d) != 1), 1)
+            for dims in zip(*padded, strict=True)
+        ]
+
+    def tensor_indexer_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt,
+    ) -> list[int | torch.SymInt]:
+        """Return dims contributed by a tensor indexer (non-broadcast case)."""
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return typing.cast("list[int | torch.SymInt]", dims)
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (
+                self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None
+            )
+        )
+
+        if block_id:
+            return [self.block_sizes[block_id].var]
+        if non_broadcast_dims:
+            return typing.cast("list[int | torch.SymInt]", non_broadcast_dims)
+        return [1]
+
+    def new_index_result(
+        self,
+        tensor: torch.Tensor,
+        output_shape: typing.Sequence[int | torch.SymInt],
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance.
+
+        The block_id is inferred from:
+        1) Existing provenance on ``tensor`` (``_tile_index_block_id``), otherwise
+        2) The first non-broadcast dimension in ``output_shape`` that maps to a block_id.
+        """
+        block_id = self.get_tile_index_tensor_block_id(tensor)
+        if block_id is None:
+            non_broadcast = [
+                i for i, s in enumerate(output_shape) if self.size_hint(s) != 1
+            ]
+            block_ids = {self.get_block_id(output_shape[i]) for i in non_broadcast}
+            block_ids.discard(None)
+            if len(block_ids) == 1:
+                block_id = block_ids.pop()
+
+        resolved_shape = list(output_shape)
+        if block_id is not None:
+            non_broadcast = [
+                i for i, s in enumerate(resolved_shape) if self.size_hint(s) != 1
+            ]
+            if len(non_broadcast) == 1:
+                resolved_shape[non_broadcast[0]] = self.block_sizes[block_id].var
+            elif len(non_broadcast) > 1:
+                block_id = None
+
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            result._tile_index_block_id = block_id  # type: ignore[attr-defined]
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
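
For intuition, here is a minimal standalone sketch of the two shape rules documented in tensor_indexer_broadcast_shape, using plain ints in place of SymInt size hints; the helper name is made up for illustration and is not part of this commit.

def broadcast_shape_sketch(shapes: list[list[int]]) -> list[int]:
    # Several 1D indexers combine as a Cartesian product, one output dim each.
    if all(len(s) == 1 for s in shapes) and len(shapes) > 1:
        return [s[0] for s in shapes]
    # Otherwise right-align the shapes and take the first non-1 size per dim.
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
    return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]

assert broadcast_shape_sketch([[8], [16]]) == [8, 16]      # Cartesian product
assert broadcast_shape_sketch([[4, 1], [1, 8]]) == [4, 8]  # ordinary broadcasting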

helion/_compiler/indexing_strategy.py

Lines changed: 180 additions & 36 deletions
@@ -575,8 +575,13 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
+
+        tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(tensors)
+        first_broadcast_tensor_idx: int | None = None
+
         k_index = 0
-        for k in index:
+        for position, k in enumerate(index):
             if k is None:
                 output_size.append(1)
             elif isinstance(k, int):
@@ -617,11 +622,13 @@ def compute_shape(
                 else:
                     output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
-                input_size.popleft()
-                output_size.extend(k.size())
+            elif isinstance(k, torch.Tensor):
+                base_dim = input_size.popleft()
+                if broadcast_shape is None:
+                    output_size.extend(env.tensor_indexer_dims(k, base_dim))
+                elif first_broadcast_tensor_idx is None:
+                    output_size.extend(broadcast_shape)
+                    first_broadcast_tensor_idx = position
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(k)
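
As a reference point for the compute_shape changes above, eager PyTorch advanced indexing yields the output shapes the broadcast path is meant to model; a quick check with plain torch (outside the compile environment):

import torch

src = torch.randn(32, 64)
i, j = torch.arange(8), torch.arange(16)
# Two broadcastable tensor indexers contribute a single broadcast shape (8, 16).
assert src[i[:, None], j[None, :]].shape == (8, 16)
# A single 1D tensor indexer contributes one dimension of its own length.
assert src[i].shape == (8, 64)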
@@ -667,13 +674,154 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        all_tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(all_tensors)
+        tensor_shapes = [list(t.size()) for t in all_tensors]
+        first_tensor_idx = 0
+        tensor_count = 0
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
 
         k_index = 0
+
+        def tensor_index_source_and_mask(
+            index_elem: torch.Tensor,
+            index_var: str,
+            pos: int,
+        ) -> tuple[str, int | None]:
+            tile_block_id = env.get_tile_index_tensor_block_id(index_elem)
+            index_source = (
+                state.codegen.index_var(tile_block_id)
+                if tile_block_id is not None
+                else index_var
+            )
+            mask_block_id = tile_block_id or (
+                env.get_block_id(output_size[pos]) if pos < len(output_size) else None
+            )
+            return index_source, mask_block_id
+
+        def handle_cartesian_tensor(
+            position: int, index_elem: torch.Tensor, index_var: str
+        ) -> bool:
+            nonlocal first_tensor_idx, output_idx, tensor_count, k_index
+
+            assert broadcast_shape is not None
+            dims = len(broadcast_shape)
+            is_cartesian = (
+                dims >= 2
+                and len(tensor_shapes) == dims
+                and all(
+                    len(shape) == 1
+                    or sum(1 for dim in shape if env.size_hint(dim) != 1) <= 1
+                    for shape in tensor_shapes
+                )
+            )
+            if not is_cartesian:
+                return False
+
+            original_length = _get_padded_iota_original_length(state, position)
+            if tensor_count == 0:
+                first_tensor_idx = output_idx
+                output_idx += dims
+
+            axis_pos = first_tensor_idx + tensor_count
+            expand_axis = tile_strategy.expand_str(output_size, axis_pos)
+            index_values.append(f"({index_var}){expand_axis}")
+
+            if original_length is not None:
+                mask_values.setdefault(
+                    f"(({index_var} < {original_length}){expand_axis})"
+                )
+
+            tensor_count += 1
+            k_index += 1
+            return True
+
+        def handle_broadcast_tensor(index_elem: torch.Tensor, index_var: str) -> bool:
+            nonlocal first_tensor_idx, output_idx, tensor_count, k_index
+
+            assert broadcast_shape is not None
+            if tensor_count == 0:
+                first_tensor_idx = output_idx
+                output_idx += len(broadcast_shape)
+
+            shape = (
+                tensor_shapes[tensor_count]
+                if tensor_count < len(tensor_shapes)
+                else [1]
+            )
+            shape_size = len(shape)
+            non_bcast_dims = sum(1 for dim in shape if env.size_hint(dim) != 1)
+            is_single_dim = non_bcast_dims <= 1
+
+            offset = max(0, len(broadcast_shape) - shape_size)
+            non_one_positions = [
+                i for i, dim in enumerate(shape) if env.size_hint(dim) != 1
+            ]
+            expand_pos = first_tensor_idx + offset
+            if is_single_dim and shape_size > 0:
+                rel_pos = non_one_positions[0] if non_one_positions else shape_size - 1
+                expand_pos += rel_pos
+
+            if output_size:
+                expand_pos = max(0, min(expand_pos, len(output_size) - 1))
+            else:
+                expand_pos = 0
+
+            # Number of dimensions to process for tensor indexing expansion
+            width = (
+                1
+                if is_single_dim
+                else min(shape_size, max(0, len(output_size) - expand_pos))
+            )
+
+            if width <= 1:
+                expand_pos_str = (
+                    tile_strategy.expand_str(output_size, expand_pos)
+                    if index_elem.ndim == 1
+                    else ""
+                )
+                index_source, mask_block_id = tensor_index_source_and_mask(
+                    index_elem,
+                    index_var,
+                    expand_pos,
+                )
+                expand = (
+                    expand_pos_str
+                    if expand_pos_str is not None
+                    else tile_strategy.expand_str(output_size, expand_pos)
+                )
+                index_values.append(f"({index_source}){expand}")
+                if tensor_count == 0 and mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
+                    ):
+                        mask_values.setdefault(f"({mask_var}){expand}")
+            else:
+                index_values.append(f"({index_var})")
+
+            if tensor_count == 0:
+                for pos in [expand_pos + d for d in range(width)]:
+                    if pos >= len(output_size):
+                        continue
+                    block_idx = env.get_block_id(output_size[pos])
+                    if block_idx is None:
+                        continue
+                    expand_str = tile_strategy.expand_str(output_size, pos)
+                    mask_var = state.codegen.mask_var(block_idx)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
+                    ):
+                        mask_values.setdefault(f"({mask_var}){expand_str}")
+
+            tensor_count += 1
+            k_index += 1
+            return True
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
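
The helpers above position each tensor index onto its output axis through tile_strategy.expand_str. As a rough sketch of that kind of expansion (the exact string format produced by expand_str is an assumption here, not taken from this diff):

def expand_str_sketch(ndim: int, axis: int) -> str:
    # Build a "[None, :, None]"-style subscript that places a 1D index
    # vector on `axis` of an `ndim`-dimensional output.
    return "[" + ", ".join(":" if d == axis else "None" for d in range(ndim)) + "]"

assert expand_str_sketch(2, 0) == "[:, None]"
assert expand_str_sketch(3, 1) == "[None, :, None]"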
@@ -752,40 +900,36 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
-                # Check if this index comes from a padded hl.arange and generate mask
-                if (
-                    original_length := _get_padded_iota_original_length(state, n)
-                ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
-                k_index += 1
-            elif (
-                isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
-            ):
-                # TODO(jansel): combine this case with the above
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == 1
-                index_var = state.codegen.lift(ast_index[0], prefix="index").id
-                index_values.append(index_var)
-                output_idx += k.ndim
-                for n, s in enumerate(output_size):
-                    if (block_idx := env.get_block_id(s)) is not None and (
-                        mask := state.codegen.mask_var(block_idx)
+
+                if broadcast_shape is not None and (
+                    handle_cartesian_tensor(n, k, index_var)
+                    or handle_broadcast_tensor(k, index_var)
+                ):
+                    continue
+
+                index_source, mask_block_id = tensor_index_source_and_mask(
+                    k, index_var, output_idx
+                )
+
+                expand = (
+                    tile_strategy.expand_str(output_size, output_idx)
+                    if k.ndim < len(output_size)
+                    else ""
+                )
+                index_values.append(f"({index_source}){expand}")
+                if mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
                     ):
-                        mask_values.setdefault(
-                            f"({mask}){tile_strategy.expand_str(output_size, n)}"
-                        )
+                        mask_values.setdefault(f"({mask_var}){expand}")
+
+                output_idx += k.ndim
+                tensor_count += 1
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))
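
Taken together, these changes target subscripts that mix tiles with loaded index tensors. A hypothetical kernel in the style of Helion's examples (illustrative only; whether this exact pattern compiles depends on the indexing paths touched in this commit):

import torch
import helion
import helion.language as hl

@helion.kernel()
def gather_rows(src: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Select rows of `src` by the integer tensor `idx`, keeping all columns.
    out = torch.empty(idx.size(0), src.size(1), dtype=src.dtype, device=src.device)
    for tile_m, tile_n in hl.tile(out.size()):
        rows = idx[tile_m]  # tensor indexer derived from a tile
        out[tile_m, tile_n] = src[rows, tile_n]
    return out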
