
Commit 515ad61

initial version
1 parent 5c71db4 commit 515ad61

File tree

8 files changed: +403 -74 lines changed


helion/_compiler/ast_extension.py

Lines changed: 21 additions & 0 deletions
@@ -8,6 +8,8 @@
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
+import torch
+
 from .. import exc
 from .source_location import SourceLocation
 from .source_location import current_location
@@ -82,10 +84,29 @@ def __repr__(self) -> str:
 
     def update_type_info(self, type_info: TypeInfo) -> TypeInfo:
         if self._type_info is not None and type_info != self._type_info:
+            prev_rank = self._tensor_rank(self._type_info)
+            new_rank = self._tensor_rank(type_info)
+            if (
+                prev_rank is not None
+                and new_rank is not None
+                and prev_rank != new_rank
+            ):
+                self._type_info = type_info
+                return self._type_info
             type_info = self._type_info.merge(type_info)
         self._type_info = type_info
         return self._type_info
 
+    @staticmethod
+    def _tensor_rank(type_info: "TypeInfo") -> int | None:
+        for attr in ["fake_value", "tensor"]:
+            obj = getattr(type_info, attr, None)
+            if attr == "tensor" and obj is not None:
+                obj = getattr(obj, "fake_value", None)
+            if isinstance(obj, torch.Tensor):
+                return obj.dim()
+        return None
+
     def debug_annotations(self) -> list[str]:
         result = []
         if self._type_info:
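
The rank check above keeps `merge()` from being called on tensor types whose dimensionality changed between propagation passes; the newer type simply wins. A minimal standalone sketch of that rule (the `FakeTypeInfo` class below is a hypothetical stand-in for helion's `TypeInfo`, used only for illustration):

# Sketch only: shows why a rank change adopts the new type instead of merging.
# ``FakeTypeInfo`` is a hypothetical stand-in for helion's TypeInfo.
import torch


class FakeTypeInfo:
    def __init__(self, fake_value: torch.Tensor) -> None:
        self.fake_value = fake_value


def tensor_rank(type_info: object) -> int | None:
    obj = getattr(type_info, "fake_value", None)
    return obj.dim() if isinstance(obj, torch.Tensor) else None


old = FakeTypeInfo(torch.empty(4, 8))       # rank 2
new = FakeTypeInfo(torch.empty(4, 8, 16))   # rank 3
assert tensor_rank(old) == 2 and tensor_rank(new) == 3
# update_type_info() would adopt ``new`` outright here rather than merge it,
# since merging tensor types of different ranks is not meaningful.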

helion/_compiler/compile_environment.py

Lines changed: 136 additions & 1 deletion
@@ -142,17 +142,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         with self.shape_env.ignore_fresh_unbacked_symbols():
             sym = self.shape_env.create_unbacked_symint()
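
The new helpers keep the reduction-loop bookkeeping unchanged while letting a reduction dimension reuse a tile's block variable. A rough sketch of that sharing, with a plain dataclass standing in for helion's `BlockSizeInfo` (the names and fields below are illustrative assumptions, not the real API):

# Sketch only: mimics _clone_block_size_as_reduction and
# _next_reduction_loop_index with a plain dataclass instead of BlockSizeInfo.
from dataclasses import dataclass


@dataclass
class FakeBlockSize:
    var: object          # stand-in for the SymInt block variable
    reduction: bool


block_sizes = [
    FakeBlockSize(var="block_var_0", reduction=False),  # a tile dimension
    FakeBlockSize(var="block_var_1", reduction=True),   # an existing reduction
]


def next_reduction_loop_index() -> int:
    # Same counting rule as _next_reduction_loop_index(): one loop index per
    # reduction dimension allocated so far.
    return sum(int(info.reduction) for info in block_sizes)


def clone_block_size_as_reduction(block_idx: int) -> FakeBlockSize:
    # The freshly allocated reduction entry reuses the tile's block var, so
    # sizes derived from either dimension stay symbolically equal.
    rdim = FakeBlockSize(var=None, reduction=True)
    rdim.var = block_sizes[block_idx].var
    block_sizes.append(rdim)
    return rdim


assert next_reduction_loop_index() == 1
rdim = clone_block_size_as_reduction(0)
assert rdim.var == block_sizes[0].var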
@@ -203,6 +233,90 @@ def cached_create_unbacked_symint(
             self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
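
The provenance tracking in this hunk is just an attribute tag on the tensor that indexing and view results copy forward. A stripped-down sketch of the pattern outside the real `CompileEnvironment` (the free functions below are illustrative stand-ins, not the class methods themselves):

# Sketch only: the tile.index provenance tag is a plain attribute on the
# tensor, read back with getattr; everything else here is illustrative.
import torch


def register_tile_index_tensor_block_id(tensor: torch.Tensor, block_id: int) -> None:
    tensor._tile_index_block_id = block_id


def get_tile_index_tensor_block_id(tensor: torch.Tensor) -> int | None:
    return getattr(tensor, "_tile_index_block_id", None)


idx = torch.arange(16)                      # stands in for a tile.index tensor
register_tile_index_tensor_block_id(idx, 0)

# A view/indexing result produced via new_index_result() would carry the tag
# forward so later indexing can recover the originating block id.
out = idx.new_empty(idx.shape)
if (bid := get_tile_index_tensor_block_id(idx)) is not None:
    register_tile_index_tensor_block_id(out, bid)
assert get_tile_index_tensor_block_id(out) == 0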
@@ -283,6 +397,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -535,3 +653,20 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
 
 def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting."""
+    if not shapes:
+        return []
+
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+
+    return [
+        next((d for d in dims if env.size_hint(d) != 1), 1)
+        for dims in zip(*padded, strict=True)
+    ]
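
For concrete integer shapes this right-aligned rule reduces to ordinary NumPy-style broadcasting. A small usage sketch, assuming `compute_broadcast_shape_for_tensor_indexers` from the hunk above is in scope and using a stub environment whose `size_hint` just returns the plain int (the stub is an assumption for illustration):

# Sketch only: a stub environment whose size_hint() returns the concrete int,
# enough to exercise compute_broadcast_shape_for_tensor_indexers on plain shapes.
class StubEnv:
    def size_hint(self, d: int) -> int:
        return int(d)


shapes = [[8, 1], [1, 4], [4]]
# Right-aligned padding yields [[8, 1], [1, 4], [1, 4]]; for each position the
# first non-broadcast (!= 1) size wins, giving [8, 4].
assert compute_broadcast_shape_for_tensor_indexers(shapes, StubEnv()) == [8, 4]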
