Commit 3c4b6d1

fix
1 parent e695ab4 commit 3c4b6d1

8 files changed: +294 -56 lines changed

helion/_compiler/compile_environment.py

Lines changed: 60 additions & 0 deletions
@@ -112,6 +112,7 @@ def __init__(
         self.specialized_vars: set[sympy.Symbol] = set()
         self.loop_dependency_checker = LoopDependencyChecker()
         self._symint_cache: dict[object, torch.SymInt] = {}
+        self._tile_index_block_ids: dict[int, int] = {}  # id(tensor) -> block_id
         self.device_load_count = (
             0  # Track number of loads in all device code for eviction policy tuning
         )
@@ -272,6 +273,65 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result

+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return self._tile_index_block_ids.get(tensor._helion_id)  # type: ignore[attr-defined]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor] | None
+    ) -> list[int | torch.SymInt] | None:
+        """Compute broadcast shape for tensor indexers, or None if not applicable."""
+        tlist = [t for t in tensors or [] if isinstance(t, torch.Tensor)]
+        if not tlist or all(self.get_tile_index_tensor_block_id(t) for t in tlist):
+            return None
+        shapes = [list(t.size()) for t in tlist]
+        if all(len(s) == 1 for s in shapes) and len(shapes) > 1:  # Cartesian
+            return [s[0] for s in shapes]
+        max_ndim = max(len(s) for s in shapes)
+        padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+        return [
+            next((d for d in dims if self.size_hint(d) != 1), 1)
+            for dims in zip(*padded, strict=True)
+        ]
+
+    def tensor_indexer_dims(
+        self, indexer_tensor: torch.Tensor, base_dim_size: int | torch.SymInt
+    ) -> list[int | torch.SymInt]:
+        """Return dims contributed by a tensor indexer (non-broadcast case)."""
+        dims = list(indexer_tensor.size())
+        non_bc = [d for d in dims if self.size_hint(d) != 1]
+        if len(non_bc) > 1:
+            return typing.cast("list[int | torch.SymInt]", dims)
+        bid = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size else None)
+            or (self.get_block_id(non_bc[0]) if non_bc else None)
+        )
+        return (
+            [self.block_sizes[bid].var]
+            if bid
+            else (typing.cast("list[int | torch.SymInt]", non_bc) or [1])
+        )
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create tensor for indexing ops, preserving tile index provenance."""
+        shape = list(output_shape)
+        non_bc = [i for i, s in enumerate(shape) if self.size_hint(s) != 1]
+        bid = self.get_tile_index_tensor_block_id(tensor)
+        if bid is None:
+            bids = {self.get_block_id(shape[i]) for i in non_bc} - {None}
+            bid = bids.pop() if len(bids) == 1 else None
+        if bid and len(non_bc) == 1:
+            shape[non_bc[0]] = self.block_sizes[bid].var
+        elif len(non_bc) > 1:
+            bid = None
+        result = tensor.new_empty(shape)
+        if bid is not None:
+            self._tile_index_block_ids[result._helion_id] = bid  # type: ignore[attr-defined]
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
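
Note: the broadcast rule implemented by `tensor_indexer_broadcast_shape` above is ordinary right-aligned broadcasting over the indexer shapes, plus a special Cartesian case when several 1-D indexers are given. The standalone sketch below mirrors that rule with plain ints standing in for the `size_hint`/`SymInt` handling; the helper name `broadcast_indexer_shape` is illustrative only and is not part of this commit.

import torch


def broadcast_indexer_shape(tensors: list[torch.Tensor]) -> list[int] | None:
    """Sketch of the broadcast-shape rule used for tensor indexers."""
    if not tensors:
        return None
    shapes = [list(t.size()) for t in tensors]
    # Cartesian case: several 1-D indexers each contribute their own dim.
    if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
        return [s[0] for s in shapes]
    # General case: right-align shapes, pad with 1s, take the non-1 size per dim.
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
    return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]


# A 2-D indexer broadcast against a 1-D indexer -> [4, 8]
print(broadcast_indexer_shape([torch.zeros(4, 8, dtype=torch.long),
                               torch.zeros(8, dtype=torch.long)]))
# Cartesian: two 1-D indexers -> [4, 8]
print(broadcast_indexer_shape([torch.zeros(4, dtype=torch.long),
                               torch.zeros(8, dtype=torch.long)]))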

helion/_compiler/indexing_strategy.py

Lines changed: 140 additions & 36 deletions
@@ -575,8 +575,13 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
+
+        tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(tensors)
+        first_broadcast_tensor_idx: int | None = None
+
         k_index = 0
-        for k in index:
+        for position, k in enumerate(index):
             if k is None:
                 output_size.append(1)
             elif isinstance(k, int):
@@ -617,11 +622,13 @@ def compute_shape(
                 else:
                     output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
-                input_size.popleft()
-                output_size.extend(k.size())
+            elif isinstance(k, torch.Tensor):
+                base_dim = input_size.popleft()
+                if broadcast_shape is None:
+                    output_size.extend(env.tensor_indexer_dims(k, base_dim))
+                elif first_broadcast_tensor_idx is None:
+                    output_size.extend(broadcast_shape)
+                    first_broadcast_tensor_idx = position
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(k)
@@ -667,13 +674,115 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        all_tensors = [k for k in index if isinstance(k, torch.Tensor)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(all_tensors)
+        tensor_shapes = [list(t.size()) for t in all_tensors]
+        first_tensor_idx = 0
+        tensor_count = 0
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)

         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)

         k_index = 0
+
+        def tensor_index_source_and_mask(
+            index_elem: torch.Tensor, index_var: str, pos: int
+        ) -> tuple[str, int | None]:
+            tile_id = env.get_tile_index_tensor_block_id(index_elem)
+            src = state.codegen.index_var(tile_id) if tile_id else index_var
+            mask_id = tile_id or (
+                env.get_block_id(output_size[pos]) if pos < len(output_size) else None
+            )
+            return src, mask_id
+
+        def handle_broadcast_tensor(
+            position: int, index_elem: torch.Tensor, index_var: str
+        ) -> None:
+            """Handle tensor index with broadcast shape (cartesian or general)."""
+            nonlocal first_tensor_idx, output_idx, tensor_count, k_index
+            assert broadcast_shape is not None
+            dims = len(broadcast_shape)
+            if tensor_count == 0:
+                first_tensor_idx = output_idx
+                output_idx += dims
+
+            shape = (
+                tensor_shapes[tensor_count]
+                if tensor_count < len(tensor_shapes)
+                else [1]
+            )
+            # Cartesian: multiple 1D tensors each contributing one dim
+            is_cartesian = (
+                dims >= 2
+                and len(tensor_shapes) == dims
+                and all(
+                    len(s) == 1 or sum(1 for d in s if env.size_hint(d) != 1) <= 1
+                    for s in tensor_shapes
+                )
+            )
+            # Find position(s) where this tensor contributes
+            offset = max(0, dims - len(shape))
+            contrib = [
+                first_tensor_idx + offset + i
+                for i, d in enumerate(shape)
+                if env.size_hint(d) != 1
+            ]
+            pos = (
+                first_tensor_idx + tensor_count
+                if is_cartesian
+                else (
+                    contrib[0]
+                    if contrib
+                    else max(
+                        0,
+                        min(
+                            first_tensor_idx + offset + len(shape) - 1,
+                            len(output_size) - 1,
+                        ),
+                    )
+                )
+            )
+            # Generate index expression
+            if is_cartesian or len(contrib) <= 1:
+                src, mask_id = tensor_index_source_and_mask(index_elem, index_var, pos)
+                expand = (
+                    tile_strategy.expand_str(output_size, pos)
+                    if index_elem.ndim == 1
+                    else ""
+                )
+                index_values.append(f"({src}){expand}")
+                if (
+                    tensor_count == 0
+                    and mask_id
+                    and (mv := state.codegen.mask_var(mask_id))
+                ):
+                    if not _is_size_one(fake_value.size(len(index_values) - 1)):
+                        mask_values.setdefault(f"({mv}){expand}")
+            else:
+                index_values.append(f"({index_var})")
+                if tensor_count == 0:
+                    for p in contrib:
+                        if p < len(output_size) and (
+                            bid := env.get_block_id(output_size[p])
+                        ):
+                            if (mv := state.codegen.mask_var(bid)) and not _is_size_one(
+                                fake_value.size(len(index_values) - 1)
+                            ):
+                                mask_values.setdefault(
+                                    f"({mv}){tile_strategy.expand_str(output_size, p)}"
+                                )
+            # Padded iota mask
+            if (
+                orig_len := _get_padded_iota_original_length(state, position)
+            ) is not None:
+                mask_values.setdefault(
+                    f"(({index_var} < {orig_len}){tile_strategy.expand_str(output_size, first_tensor_idx + tensor_count)})"
+                )
+            tensor_count += 1
+            k_index += 1
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -752,40 +861,35 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
-                # Check if this index comes from a padded hl.arange and generate mask
-                if (
-                    original_length := _get_padded_iota_original_length(state, n)
-                ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
-                k_index += 1
-            elif (
-                isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
-            ):
-                # TODO(jansel): combine this case with the above
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == 1
-                index_var = state.codegen.lift(ast_index[0], prefix="index").id
-                index_values.append(index_var)
-                output_idx += k.ndim
-                for n, s in enumerate(output_size):
-                    if (block_idx := env.get_block_id(s)) is not None and (
-                        mask := state.codegen.mask_var(block_idx)
+
+                # Use broadcast handling for: multiple tensors, or single tensor with ndim > 1
+                if broadcast_shape is not None and (len(all_tensors) > 1 or k.ndim > 1):
+                    handle_broadcast_tensor(n, k, index_var)
+                    continue
+
+                index_source, mask_block_id = tensor_index_source_and_mask(
+                    k, index_var, output_idx
+                )
+
+                expand = (
+                    tile_strategy.expand_str(output_size, output_idx)
+                    if k.ndim < len(output_size)
+                    else ""
+                )
+                index_values.append(f"({index_source}){expand}")
+                if mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
                     ):
-                        mask_values.setdefault(
-                            f"({mask}){tile_strategy.expand_str(output_size, n)}"
-                        )
+                        mask_values.setdefault(f"({mask_var}){expand}")
+
+                output_idx += k.ndim
+                tensor_count += 1
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))
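
Note: for intuition, the two situations `handle_broadcast_tensor` distinguishes correspond to familiar eager-mode indexing shapes. The sketch below is plain PyTorch used only to illustrate the output shapes involved (the Cartesian case matches the "multiple 1D tensors each contributing one dim" comment above); it is not the Triton code this path emits.

import torch

# Eager-mode analogue of the two shapes handle_broadcast_tensor distinguishes.
x = torch.arange(64, dtype=torch.float32).reshape(8, 8)

# General broadcast case: a single 2-D indexer contributes all of its dims.
idx2d = torch.tensor([[0, 1], [2, 3]])
assert x[idx2d, 0].shape == (2, 2)

# Cartesian case: the eager equivalent of several 1-D indexers each
# contributing one output dim is outer-product indexing via unsqueezing.
rows = torch.tensor([0, 1, 2])
cols = torch.tensor([4, 5])
assert x[rows[:, None], cols[None, :]].shape == (3, 2)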

helion/_compiler/type_propagation.py

Lines changed: 46 additions & 8 deletions
@@ -460,7 +460,10 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
         inputs_consumed = 0
         output_sizes = []
         env = CompileEnvironment.current()
-        for k in keys:
+        tensor_indexers = [k.fake_value for k in keys if isinstance(k, TensorType)]
+        broadcast_shape = env.tensor_indexer_broadcast_shape(tensor_indexers)
+        first_broadcast_tensor_idx: int | None = None
+        for position, k in enumerate(keys):
             if isinstance(k, LiteralType):
                 if isinstance(k.value, (int, torch.SymInt)):
                     inputs_consumed += 1
@@ -505,9 +508,19 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
                 raise exc.DataDependentOutputShapeNotSupported(
                     op_desc="Boolean mask indexing (tensor[boolean_mask])"
                 )
-            elif isinstance(k, TensorType) and k.fake_value.ndim == 1:
+            elif isinstance(k, TensorType):
+                base_dim_size = self.fake_value.size(inputs_consumed)
                 inputs_consumed += 1
-                output_sizes.append(k.fake_value.size(0))
+                if broadcast_shape is None:
+                    output_sizes.extend(
+                        env.tensor_indexer_dims(
+                            k.fake_value,
+                            base_dim_size,
+                        )
+                    )
+                elif first_broadcast_tensor_idx is None:
+                    output_sizes.extend(broadcast_shape)
+                    first_broadcast_tensor_idx = position
             elif k.contains_type(TileIndexType):
                 raise exc.OverpackedTile(k)
             else:
@@ -553,9 +566,11 @@ def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
             raise exc.TypeInferenceError(
                 f"Subscript not supported on {self!s} with key={key!s}"
             ) from None
-        return TensorType(
-            origin, self.fake_value.new_empty(self._device_indexing_size(key))
-        )
+        new_sizes = self._device_indexing_size(key)
+        env = CompileEnvironment.current()
+        new_fake = env.new_index_result(self.fake_value, new_sizes)
+
+        return TensorType(origin, new_fake)

     def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo:
         if isinstance(other, TensorType):
@@ -2143,8 +2158,31 @@ def visit_NamedExpr(self, node: ast.NamedExpr) -> TypeInfo:
         return type_info

     def visit_Subscript(self, node: ast.Subscript) -> TypeInfo:
-        value_type = self.visit(node.value)
-        slice_type = self.visit(node.slice)
+        value_type, slice_type = self.visit(node.value), self.visit(node.slice)
+        # In device loops, check for overpacked tiles and rank mismatch
+        if self.device_loop_depth > 0 and isinstance(value_type, TensorType):
+            keys = (
+                slice_type.unpack()
+                if isinstance(slice_type, SequenceType)
+                else [slice_type]
+            )
+            consumed, has_tensor = 0, False
+            for k in keys:
+                if k.contains_type(TileIndexType) and not isinstance(k, TileIndexType):
+                    raise exc.OverpackedTile(k)
+                if isinstance(k, TensorType):
+                    has_tensor, consumed = True, consumed + 1
+                elif isinstance(k, SliceType) or (
+                    isinstance(k, (LiteralType, SymIntType, TileIndexType))
+                    and not (isinstance(k, LiteralType) and k.value is None)
+                ):
+                    consumed += 1
+            if not has_tensor and consumed < value_type.fake_value.ndim:
+                raise exc.RankMismatch(
+                    value_type.fake_value.ndim,
+                    consumed,
+                    f"tensor shape: {tuple(value_type.fake_value.shape)}, indexed {consumed} dimensions",
+                )
         return value_type.propagate_getitem(slice_type, self.origin())

     def visit_Slice(self, node: ast.Slice) -> TypeInfo:
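
Note: the rank check added to `visit_Subscript` above counts how many input dimensions a subscript consumes and raises `RankMismatch` when a tensor-free subscript covers fewer dims than the value has. A simplified standalone sketch of that check on plain Python values follows; the `check_rank` helper is hypothetical and stands in for the `TypeInfo`-based logic.

import torch


def check_rank(tensor: torch.Tensor, keys: tuple[object, ...]) -> None:
    """Sketch of the rank check: each non-None key consumes one input dim;
    a subscript with no tensor indexer must consume every dim."""
    consumed = sum(1 for k in keys if k is not None)
    has_tensor = any(isinstance(k, torch.Tensor) for k in keys)
    if not has_tensor and consumed < tensor.ndim:
        raise ValueError(
            f"rank mismatch: tensor shape {tuple(tensor.shape)}, "
            f"indexed {consumed} of {tensor.ndim} dimensions"
        )


x = torch.zeros(4, 8)
check_rank(x, (slice(None), 0))      # ok: both dims indexed
check_rank(x, (torch.tensor([0]),))  # ok: tensor indexers are exempt
try:
    check_rank(x, (0,))              # raises: only one of two dims indexed
except ValueError as e:
    print(e)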
