
Conversation

@yf225 (Contributor) commented on Sep 12, 2025

Fixes #546.

@meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Sep 12, 2025
@yf225 force-pushed the indirect_indexing_v2 branch 26 times, most recently from 28424b9 to 515ad61 on September 16, 2025
@yf225 force-pushed the indirect_indexing_v2 branch 3 times, most recently from 252f6e0 to 37daf60 on November 14, 2025
@yf225 force-pushed the indirect_indexing_v2 branch 14 times, most recently from 7d9f694 to 9ded093 on November 22, 2025
@yf225 changed the title from "[WIP] Add 2d and 3d indirect indexing support" to "Add 2d and 3d indirect indexing support" on November 22, 2025
@yf225 requested review from jansel and oulgen on November 22, 2025
@yf225 marked this pull request as ready for review on November 22, 2025
@yf225 force-pushed the indirect_indexing_v2 branch from 9ded093 to 312a854 on November 22, 2025
@yf225 force-pushed the indirect_indexing_v2 branch from 312a854 to 9029b5a on November 24, 2025

def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
"""Return the originating ``tile.index`` block id if present."""
return getattr(tensor, "_tile_index_block_id", None)
Reviewer (Contributor) commented:

What is going on here? We should either have an attribute or not -- this getattr stuff is hacky. I think there should be a cleaner way to get this.
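
One cleaner direction, sketched below, is to keep the mapping in an explicit side table rather than a dynamically attached attribute. This is only an illustration of the kind of alternative the reviewer is asking for, not Helion's actual API; TileIndexRegistry and its methods are hypothetical.

import weakref

import torch

class TileIndexRegistry:
    """Hypothetical side table mapping tensors to the block id of the
    ``tile.index`` call that produced them, replacing the ad-hoc
    ``_tile_index_block_id`` attribute on the tensors themselves."""

    def __init__(self) -> None:
        # Weak keys: an entry disappears when its tensor is garbage
        # collected, so the registry never keeps tensors alive.
        self._block_ids: weakref.WeakKeyDictionary[torch.Tensor, int] = (
            weakref.WeakKeyDictionary()
        )

    def register(self, tensor: torch.Tensor, block_id: int) -> None:
        self._block_ids[tensor] = block_id

    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
        # Explicit lookup instead of getattr(tensor, "_tile_index_block_id", None).
        return self._block_ids.get(tensor)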

output_size.extend(k.size())
elif isinstance(k, torch.Tensor):
base_dim = input_size.popleft()
if broadcast_shape is None:
Reviewer (Contributor) commented:

Why do we need two cases here?
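
For context (my illustration, not code from this PR): the two branches presumably implement the usual accumulation where the first tensor index encountered establishes the broadcast shape and every later tensor index is broadcast against it. A minimal standalone sketch of that pattern:

import torch

idx_a = torch.zeros(4, 1, dtype=torch.long)
idx_b = torch.zeros(1, 5, dtype=torch.long)

broadcast_shape = None
for k in (idx_a, idx_b):
    if broadcast_shape is None:
        # First tensor index seen: it establishes the broadcast shape.
        broadcast_shape = tuple(k.shape)
    else:
        # Later tensor indices are broadcast against the running shape.
        broadcast_shape = torch.broadcast_shapes(broadcast_shape, k.shape)

print(broadcast_shape)  # torch.Size([4, 5])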

Comment on lines 689 to 785

def tensor_index_source_and_mask(
index_elem: torch.Tensor,
index_var: str,
pos: int,
) -> tuple[str, int | None]:
tile_block_id = env.get_tile_index_tensor_block_id(index_elem)
index_source = (
state.codegen.index_var(tile_block_id)
if tile_block_id is not None
else index_var
)
mask_block_id = tile_block_id or (
env.get_block_id(output_size[pos]) if pos < len(output_size) else None
)
return index_source, mask_block_id

def handle_cartesian_tensor(
position: int, index_elem: torch.Tensor, index_var: str
) -> bool:
nonlocal first_tensor_idx, output_idx, tensor_count, k_index

assert broadcast_shape is not None
dims = len(broadcast_shape)
is_cartesian = (
dims >= 2
and len(tensor_shapes) == dims
and all(
len(shape) == 1
or sum(1 for dim in shape if env.size_hint(dim) != 1) <= 1
for shape in tensor_shapes
)
)
if not is_cartesian:
return False

original_length = _get_padded_iota_original_length(state, position)
if tensor_count == 0:
first_tensor_idx = output_idx
output_idx += dims

axis_pos = first_tensor_idx + tensor_count
expand_axis = tile_strategy.expand_str(output_size, axis_pos)
index_values.append(f"({index_var}){expand_axis}")

if original_length is not None:
mask_values.setdefault(
f"(({index_var} < {original_length}){expand_axis})"
)

tensor_count += 1
k_index += 1
return True

def handle_broadcast_tensor(index_elem: torch.Tensor, index_var: str) -> bool:
nonlocal first_tensor_idx, output_idx, tensor_count, k_index

assert broadcast_shape is not None
if tensor_count == 0:
first_tensor_idx = output_idx
output_idx += len(broadcast_shape)

shape = (
tensor_shapes[tensor_count]
if tensor_count < len(tensor_shapes)
else [1]
)
shape_size = len(shape)
non_bcast_dims = sum(1 for dim in shape if env.size_hint(dim) != 1)
is_single_dim = non_bcast_dims <= 1

offset = max(0, len(broadcast_shape) - shape_size)
non_one_positions = [
i for i, dim in enumerate(shape) if env.size_hint(dim) != 1
]
expand_pos = first_tensor_idx + offset
if is_single_dim and shape_size > 0:
rel_pos = non_one_positions[0] if non_one_positions else shape_size - 1
expand_pos += rel_pos

if output_size:
expand_pos = max(0, min(expand_pos, len(output_size) - 1))
else:
expand_pos = 0

# Number of dimensions to process for tensor indexing expansion
width = (
1
if is_single_dim
else min(shape_size, max(0, len(output_size) - expand_pos))
)

if width <= 1:
expand_pos_str = (
tile_strategy.expand_str(output_size, expand_pos)
if index_elem.ndim == 1
else ""
)
index_source, mask_block_id = tensor_index_source_and_mask(
index_elem,
index_var,
expand_pos,
)
expand = (
expand_pos_str
if expand_pos_str is not None
else tile_strategy.expand_str(output_size, expand_pos)
)
index_values.append(f"({index_source}){expand}")
if tensor_count == 0 and mask_block_id is not None:
mask_var = state.codegen.mask_var(mask_block_id)
if mask_var and not _is_size_one(
fake_value.size(len(index_values) - 1)
):
mask_values.setdefault(f"({mask_var}){expand}")
else:
index_values.append(f"({index_var})")

if tensor_count == 0:
for pos in [expand_pos + d for d in range(width)]:
if pos >= len(output_size):
continue
block_idx = env.get_block_id(output_size[pos])
if block_idx is None:
continue
expand_str = tile_strategy.expand_str(output_size, pos)
mask_var = state.codegen.mask_var(block_idx)
if mask_var and not _is_size_one(
fake_value.size(len(index_values) - 1)
):
mask_values.setdefault(f"({mask_var}){expand_str}")

tensor_count += 1
k_index += 1
return True

Reviewer (Contributor) commented:

The changes here complicate this function a lot. Is there a reason it needs to be so complex? Are there really this many different cases?
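
For readers following the split between handle_cartesian_tensor and handle_broadcast_tensor: it appears to mirror the two advanced-indexing patterns PyTorch (and NumPy) support, shown in this small standalone example:

import torch

x = torch.arange(12).reshape(3, 4)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 3])

# Paired (broadcast) indexing: same-shape index tensors are matched
# element-wise, yielding shape (2,).
paired = x[rows, cols]                   # tensor([ 1, 11])

# Cartesian (outer) indexing: the index tensors broadcast against each
# other and select the full cross product, yielding shape (2, 2).
outer = x[rows[:, None], cols[None, :]]  # tensor([[ 1,  3],
                                         #         [ 9, 11]])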

@yf225 force-pushed the indirect_indexing_v2 branch 4 times, most recently from 4616898 to 87fd131 on November 26, 2025
@yf225 changed the title from "Add 2d and 3d indirect indexing support" to "[WIP] Add 2d and 3d indirect indexing support" on November 26, 2025
@yf225 force-pushed the indirect_indexing_v2 branch from 87fd131 to 3c4b6d1 on November 26, 2025
@yf225 force-pushed the indirect_indexing_v2 branch from 3c4b6d1 to fc9a862 on November 26, 2025


Development

Successfully merging this pull request may close the issue "Helion indirect indexing with higher dimensions" (#546).
