|
3 | 3 | import ast
|
4 | 4 | import builtins
|
5 | 5 | import inspect
|
| 6 | +import itertools |
6 | 7 | from itertools import starmap
|
7 | 8 | from typing import TYPE_CHECKING
|
8 | 9 | from typing import Iterator
|
@@ -449,6 +450,81 @@ def _(state: CodegenState) -> ast.AST:
|
449 | 450 | return _codegen_loop_helper(state)
|
450 | 451 |
|
451 | 452 |
|
| 453 | +@_decorators.ref(tile) |
| 454 | +def _( |
| 455 | + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], |
| 456 | + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, |
| 457 | + block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None, |
| 458 | +) -> Iterator[slice | tuple[slice, ...]]: |
| 459 | + # Convert tensor values to int |
| 460 | + def _to_int(value): |
| 461 | + if value is None: |
| 462 | + return None |
| 463 | + if isinstance(value, torch.Tensor): |
| 464 | + return int(value.item()) |
| 465 | + return int(value) |
| 466 | + |
| 467 | + # Step 1: Normalize begin and end values based on the number of arguments |
| 468 | + if end_or_none is not None: |
| 469 | + # Two positional args: begin_or_end is begin, end_or_none is end |
| 470 | + begin = begin_or_end |
| 471 | + end = end_or_none |
| 472 | + else: |
| 473 | + # One positional arg: begin_or_end is end, begin defaults to 0 |
| 474 | + end = begin_or_end |
| 475 | + # Create begin with same structure as end, but all zeros |
| 476 | + if isinstance(end, (list, tuple)): |
| 477 | + begin = [0] * len(end) |
| 478 | + else: |
| 479 | + begin = 0 |
| 480 | + |
| 481 | + # Step 2: Convert inputs to lists for uniform handling |
| 482 | + def _normalize_to_list( |
| 483 | + value: int | torch.Tensor | list[int | torch.Tensor], |
| 484 | + ) -> list[int | torch.Tensor]: |
| 485 | + if isinstance(value, (list, tuple)): |
| 486 | + return list(value) |
| 487 | + return [value] |
| 488 | + |
| 489 | + begin_list = _normalize_to_list(begin) |
| 490 | + end_list = _normalize_to_list(end) |
| 491 | + |
| 492 | + # Convert all values to int |
| 493 | + begin_list = [_to_int(b) for b in begin_list] |
| 494 | + end_list = [_to_int(e) for e in end_list] |
| 495 | + |
| 496 | + # Step 3: Determine block_size based on the arguments |
| 497 | + if block_size is None: |
| 498 | + # Default block_size to end - begin for each dimension |
| 499 | + block_size_list = [e - b for b, e in zip(begin_list, end_list, strict=False)] |
| 500 | + else: |
| 501 | + block_size_list = _normalize_to_list(block_size) |
| 502 | + block_size_list = [ |
| 503 | + _to_int(bs) if bs is not None else (e - b) |
| 504 | + for bs, b, e in zip(block_size_list, begin_list, end_list, strict=False) |
| 505 | + ] |
| 506 | + |
| 507 | + # Step 4: Yield tile ranges |
| 508 | + # Handle single dimension case |
| 509 | + if len(begin_list) == 1: |
| 510 | + b = begin_list[0] |
| 511 | + e = end_list[0] |
| 512 | + bs = block_size_list[0] |
| 513 | + for i in range(b, e, bs): |
| 514 | + yield slice(i, min(i + bs, e)) |
| 515 | + else: |
| 516 | + # Handle multi-dimensional case |
| 517 | + ranges = [] |
| 518 | + for b, e, bs in zip(begin_list, end_list, block_size_list, strict=False): |
| 519 | + dim_ranges = [] |
| 520 | + for i in range(b, e, bs): |
| 521 | + dim_ranges.append(slice(i, min(i + bs, e))) |
| 522 | + ranges.append(dim_ranges) |
| 523 | + |
| 524 | + for combo in itertools.product(*ranges): |
| 525 | + yield combo |
| 526 | + |
| 527 | + |
452 | 528 | def _codegen_loop_helper(
|
453 | 529 | state: CodegenState,
|
454 | 530 | ) -> ast.AST:
|
@@ -637,6 +713,32 @@ def _(state: CodegenState) -> ast.AST:
|
637 | 713 | return _codegen_loop_helper(state)
|
638 | 714 |
|
639 | 715 |
|
| 716 | +@_decorators.ref(grid) |
| 717 | +def _( |
| 718 | + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], |
| 719 | + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, |
| 720 | + step: object = None, |
| 721 | +) -> range | Iterator[tuple[int, ...]]: |
| 722 | + # Similar to tile but yields indices instead of slices |
| 723 | + if end_or_none is not None: |
| 724 | + begin = begin_or_end |
| 725 | + end = end_or_none |
| 726 | + else: |
| 727 | + end = begin_or_end |
| 728 | + if isinstance(end, (list, tuple)): |
| 729 | + begin = [0] * len(end) |
| 730 | + else: |
| 731 | + begin = 0 |
| 732 | + |
| 733 | + # Handle single dimension |
| 734 | + if not isinstance(begin, (list, tuple)): |
| 735 | + return range(begin, end) |
| 736 | + |
| 737 | + # Handle multi-dimensional |
| 738 | + ranges = list(itertools.starmap(range, zip(begin, end, strict=False))) |
| 739 | + return itertools.product(*ranges) |
| 740 | + |
| 741 | + |
640 | 742 | @_decorators.device_func_replacement(builtins.zip)
|
641 | 743 | @_decorators.api(is_device_only=True, cache_type=True)
|
642 | 744 | def _zip_replacement(
|
@@ -898,3 +1000,14 @@ def _(
|
898 | 1000 |
|
899 | 1001 | # Return tuple(range(...)) which will trigger existing tuple/list unrolling
|
900 | 1002 | return tuple(range(begin_val, end_val, step))
|
| 1003 | + |
| 1004 | + |
| 1005 | +@_decorators.ref(static_range) |
| 1006 | +def _( |
| 1007 | + begin_or_end: int, |
| 1008 | + end_or_none: int | None = None, |
| 1009 | + step: int = 1, |
| 1010 | +) -> range: |
| 1011 | + if end_or_none is not None: |
| 1012 | + return range(begin_or_end, end_or_none, step) |
| 1013 | + return range(begin_or_end) |
0 commit comments