Skip to content

Commit aa37807

Browse files
committed
Address review comments
1 parent f6055e2 commit aa37807

File tree

14 files changed

+155
-121
lines changed

14 files changed

+155
-121
lines changed

tripy/tests/flat_ir/ops/test_broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,6 @@ def test_str(self):
3333
broadcast = func_broadcast.ops[-1]
3434
assert isinstance(broadcast, DynamicBroadcastOp)
3535
assert re.match(
36-
r"t_inter3: \[rank=\(2\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter4, t_inter2, broadcast_dim=\[\]\)",
36+
r"t_inter[0-9]+: \[rank=\(2\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[\]\)",
3737
str(broadcast),
3838
)

tripy/tests/flat_ir/ops/test_gather.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tripy as tp
2020

2121
from tripy.flat_ir.ops import DynamicGatherOp
22-
from tripy.flat_ir.ops.base import FlatIRFunction
22+
from tripy.flat_ir.function import FlatIRFunction
2323
from tripy.frontend.trace import Trace
2424
import re
2525

tripy/tests/flat_ir/ops/test_maximum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import re
1919
import tripy as tp
2020
from tripy.frontend.trace import Trace
21+
from tripy.flat_ir.function import FlatIRFunction
2122
from tripy.flat_ir.ops import MaxOp
22-
from tripy.flat_ir.ops.base import FlatIRFunction
2323

2424

2525
class TestMaxOp:

tripy/tests/flat_ir/ops/test_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
import tripy as tp
1919
from tripy.frontend.trace import Trace
20+
from tripy.flat_ir.function import FlatIRFunction
2021
from tripy.flat_ir.ops import ArgMinMaxOp, ConvertOp, DivideOp, DynamicBroadcastOp, MulOp, ReduceOp
21-
from tripy.flat_ir.ops.base import FlatIRFunction
2222
import re
2323

2424

tripy/tests/flat_ir/ops/test_subtract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import re
1919
import tripy as tp
2020
from tripy.frontend.trace import Trace
21+
from tripy.flat_ir.function import FlatIRFunction
2122
from tripy.flat_ir.ops import SubtractOp
22-
from tripy.flat_ir.ops.base import FlatIRFunction
2323

2424

2525
class TestSubtractOp:

tripy/tests/flat_ir/test_constant_deduplication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import pytest
1717
from tripy.flat_ir.flat_ir import FlatIR
18-
from tripy.flat_ir.ops.base import FlatIRFunction
18+
from tripy.flat_ir.function import FlatIRFunction
1919
from tripy.flat_ir.ops import ConstantOp
2020
from tripy.flat_ir.tensor import FlatIRTensor
2121
from tripy.common.device import device

tripy/tests/flat_ir/test_function_deduplication.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from typing import List, Optional
1919

2020
from tripy.flat_ir.flat_ir import FlatIR
21-
from tripy.flat_ir.ops.base import FlatIRFunction, BaseFlatIROp
21+
from tripy.flat_ir.function import FlatIRFunction
22+
from tripy.flat_ir.ops.base import BaseFlatIROp
2223
from tripy.flat_ir.ops import ConstantOp
2324
from tripy.flat_ir.tensor import FlatIRTensor
2425
from tripy.common.device import device
@@ -39,7 +40,7 @@ def __eq__(self, other):
3940
return True
4041

4142
def to_mlir(self, operands):
42-
assert "Not implemented"
43+
raise NotImplementedError()
4344

4445

4546
def test_is_structurally_equivalent():

tripy/tripy/backend/mlir/memref.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
@lru_cache(maxsize=None)
29-
def _cached_create_memref(shape: Sequence[int], dtype: str, device_kind: str, stream):
29+
def _cached_create_empty_memref(shape: Sequence[int], dtype: str, device_kind: str, stream):
3030
mlirtrt_device = mlir_utils.MLIRRuntimeClient().get_devices()[0] if device_kind == "gpu" else None
3131
mlir_dtype = mlir_utils.convert_tripy_dtype_to_runtime_dtype(dtype)
3232
return mlir_utils.MLIRRuntimeClient().create_memref(
@@ -55,9 +55,10 @@ def create_empty_memref(
5555
5656
"""
5757
if use_cache:
58-
return _cached_create_memref(tuple(shape), dtype, device.kind, stream)
58+
assert common_utils.is_shape_empty(shape)
59+
return _cached_create_empty_memref(tuple(shape), dtype, device.kind, stream)
5960
else:
60-
return _cached_create_memref.__wrapped__(tuple(shape), dtype, device.kind, stream)
61+
return _cached_create_empty_memref.__wrapped__(tuple(shape), dtype, device.kind, stream)
6162

6263

6364
def create_memref_view(data):

tripy/tripy/common/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@ def is_empty(data: Sequence) -> bool:
7474
return isinstance(data, Sequence) and all(map(is_empty, data))
7575

7676

77+
def is_shape_empty(shape: Sequence[int]) -> bool:
78+
"""
79+
A shape is considered empty if any of its dimensions is zero.
80+
81+
Args:
82+
shape (Tuple[int, ...]): A tuple representing the shape of a tensor.
83+
84+
Returns:
85+
bool: True if the shape represents an empty tensor, False otherwise.
86+
"""
87+
return any(dim == 0 for dim in shape)
88+
89+
7790
class Float16MemoryView:
7891
"""
7992
A custom memory view class for handling float16 data.

tripy/tripy/flat_ir/flat_ir.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self, shapes: Sequence[ShapeBounds] = None):
4343

4444
def __str__(self) -> str:
4545
"""Generate a string representation of the FlatIR."""
46-
from tripy.flat_ir.ops.base import FlatIRFunction, BaseFlatIROp
46+
from tripy.flat_ir.function import FlatIRFunction
47+
from tripy.flat_ir.ops.base import BaseFlatIROp
4748

4849
ir = []
4950

@@ -96,7 +97,8 @@ def to_mlir(self):
9697
from mlir_tensorrt.compiler import ir
9798
from mlir_tensorrt.compiler.dialects import func as func_dialect
9899
from tripy.backend.mlir.utils import make_ir_context, make_tensor_location
99-
from tripy.flat_ir.ops.base import FlatIRFunction, BaseFlatIROp
100+
from tripy.flat_ir.function import FlatIRFunction
101+
from tripy.flat_ir.ops.base import BaseFlatIROp
100102

101103
def _base_op_to_mlir(op, mlir_tensor_map):
102104
op_inputs = [mlir_tensor_map[input_tensor.name] for input_tensor in op.inputs]
@@ -376,7 +378,8 @@ def integrate_subgraph(self, inputs: List["FlatIRTensor"], outputs: List["FlatIR
376378
"""
377379
Integrate a subgraph delineated by the given inputs and outputs into this FlatIR.
378380
"""
379-
from tripy.flat_ir.ops.base import BaseFlatIROp, FlatIRFunction
381+
from tripy.flat_ir.function import FlatIRFunction
382+
from tripy.flat_ir.ops.base import BaseFlatIROp
380383
from tripy.flat_ir.tensor import FlatIRTensor
381384
from tripy.flat_ir.ops import ConstantOp
382385

0 commit comments

Comments
 (0)