[WIP] Add 2d and 3d indirect indexing support #593
Conversation
def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
    """Return the originating ``tile.index`` block id if present."""
    return getattr(tensor, "_tile_index_block_id", None)
What is going on here? We should either have an attribute or not -- this getattr stuff is hacky. I think there should be a cleaner way to get this.
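One possible cleaner shape, sketched below purely as an illustration rather than a concrete proposal: keep the tensor-to-block-id mapping in an explicit side table owned by the environment instead of probing a private attribute with getattr. The class and method names are hypothetical and not part of this PR; the sketch assumes torch.utils.weak.WeakTensorKeyDictionary is available.

    import torch
    from torch.utils.weak import WeakTensorKeyDictionary

    class TileIndexRegistry:
        """Hypothetical side table mapping a tile.index tensor to its block id."""

        def __init__(self) -> None:
            # Weak keys so the registry does not keep index tensors alive.
            self._block_ids = WeakTensorKeyDictionary()

        def register(self, tensor: torch.Tensor, block_id: int) -> None:
            # Called wherever tile.index currently tags the tensor.
            self._block_ids[tensor] = block_id

        def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
            # Explicit lookup instead of getattr on an ad-hoc private attribute.
            return self._block_ids.get(tensor)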
        output_size.extend(k.size())
    elif isinstance(k, torch.Tensor):
        base_dim = input_size.popleft()
        if broadcast_shape is None:
Why do we need two cases here?
def tensor_index_source_and_mask(
    index_elem: torch.Tensor,
    index_var: str,
    pos: int,
) -> tuple[str, int | None]:
    tile_block_id = env.get_tile_index_tensor_block_id(index_elem)
    index_source = (
        state.codegen.index_var(tile_block_id)
        if tile_block_id is not None
        else index_var
    )
    mask_block_id = tile_block_id or (
        env.get_block_id(output_size[pos]) if pos < len(output_size) else None
    )
    return index_source, mask_block_id

def handle_cartesian_tensor(
    position: int, index_elem: torch.Tensor, index_var: str
) -> bool:
    nonlocal first_tensor_idx, output_idx, tensor_count, k_index

    assert broadcast_shape is not None
    dims = len(broadcast_shape)
    is_cartesian = (
        dims >= 2
        and len(tensor_shapes) == dims
        and all(
            len(shape) == 1
            or sum(1 for dim in shape if env.size_hint(dim) != 1) <= 1
            for shape in tensor_shapes
        )
    )
    if not is_cartesian:
        return False

    original_length = _get_padded_iota_original_length(state, position)
    if tensor_count == 0:
        first_tensor_idx = output_idx
        output_idx += dims

    axis_pos = first_tensor_idx + tensor_count
    expand_axis = tile_strategy.expand_str(output_size, axis_pos)
    index_values.append(f"({index_var}){expand_axis}")

    if original_length is not None:
        mask_values.setdefault(
            f"(({index_var} < {original_length}){expand_axis})"
        )

    tensor_count += 1
    k_index += 1
    return True

def handle_broadcast_tensor(index_elem: torch.Tensor, index_var: str) -> bool:
    nonlocal first_tensor_idx, output_idx, tensor_count, k_index

    assert broadcast_shape is not None
    if tensor_count == 0:
        first_tensor_idx = output_idx
        output_idx += len(broadcast_shape)

    shape = (
        tensor_shapes[tensor_count]
        if tensor_count < len(tensor_shapes)
        else [1]
    )
    shape_size = len(shape)
    non_bcast_dims = sum(1 for dim in shape if env.size_hint(dim) != 1)
    is_single_dim = non_bcast_dims <= 1

    offset = max(0, len(broadcast_shape) - shape_size)
    non_one_positions = [
        i for i, dim in enumerate(shape) if env.size_hint(dim) != 1
    ]
    expand_pos = first_tensor_idx + offset
    if is_single_dim and shape_size > 0:
        rel_pos = non_one_positions[0] if non_one_positions else shape_size - 1
        expand_pos += rel_pos

    if output_size:
        expand_pos = max(0, min(expand_pos, len(output_size) - 1))
    else:
        expand_pos = 0

    # Number of dimensions to process for tensor indexing expansion
    width = (
        1
        if is_single_dim
        else min(shape_size, max(0, len(output_size) - expand_pos))
    )

    if width <= 1:
        expand_pos_str = (
            tile_strategy.expand_str(output_size, expand_pos)
            if index_elem.ndim == 1
            else ""
        )
        index_source, mask_block_id = tensor_index_source_and_mask(
            index_elem,
            index_var,
            expand_pos,
        )
        expand = (
            expand_pos_str
            if expand_pos_str is not None
            else tile_strategy.expand_str(output_size, expand_pos)
        )
        index_values.append(f"({index_source}){expand}")
        if tensor_count == 0 and mask_block_id is not None:
            mask_var = state.codegen.mask_var(mask_block_id)
            if mask_var and not _is_size_one(
                fake_value.size(len(index_values) - 1)
            ):
                mask_values.setdefault(f"({mask_var}){expand}")
    else:
        index_values.append(f"({index_var})")

    if tensor_count == 0:
        for pos in [expand_pos + d for d in range(width)]:
            if pos >= len(output_size):
                continue
            block_idx = env.get_block_id(output_size[pos])
            if block_idx is None:
                continue
            expand_str = tile_strategy.expand_str(output_size, pos)
            mask_var = state.codegen.mask_var(block_idx)
            if mask_var and not _is_size_one(
                fake_value.size(len(index_values) - 1)
            ):
                mask_values.setdefault(f"({mask_var}){expand_str}")

    tensor_count += 1
    k_index += 1
    return True
The changes here complicate this function a lot. Is there a reason it needs to be so complex? Are there really this many different cases?
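For context on why more than one case exists at all, here is a small eager-mode PyTorch example (not taken from this PR) contrasting the two advanced-indexing patterns the helpers above appear to distinguish: a cartesian-product style where each index tensor gets its own output axis, and a broadcast style where the index tensors share one output shape.

    import torch

    src = torch.arange(12).reshape(3, 4)

    # Cartesian style: indices are reshaped so they broadcast against each other,
    # and each index tensor contributes a separate output dimension.
    rows = torch.tensor([0, 2])            # shape (2,)
    cols = torch.tensor([1, 3, 0])         # shape (3,)
    cart = src[rows[:, None], cols[None, :]]
    assert cart.shape == (2, 3)            # 2 x 3 cartesian product of the index sets

    # Broadcast style: the index tensors already share a shape, so the indexed
    # dimensions collapse into that single shape instead of multiplying out.
    rows2 = torch.tensor([0, 2, 1])        # shape (3,)
    cols2 = torch.tensor([1, 3, 0])        # shape (3,)
    pairs = src[rows2, cols2]
    assert pairs.shape == (3,)             # element-wise (row, col) pairs

A kernel-side implementation has to reproduce both output shapes, which is presumably where the cartesian/broadcast split in this diff comes from.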
Fixes #546.