Skip to content

Commit 5b12456

Browse files
committed
Address review comments
1 parent f24352b commit 5b12456

File tree

14 files changed

+164
-130
lines changed

14 files changed

+164
-130
lines changed

tripy/tests/flat_ir/ops/test_broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,6 @@ def test_str(self):
3333
broadcast = func_broadcast.ops[-1]
3434
assert isinstance(broadcast, DynamicBroadcastOp)
3535
assert re.match(
36-
r"t_inter3: \[rank=\(2\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter4, t_inter2, broadcast_dim=\[\]\)",
36+
r"t_inter[0-9]+: \[rank=\(2\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[\]\)",
3737
str(broadcast),
3838
)

tripy/tests/flat_ir/ops/test_gather.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tripy as tp
2020

2121
from tripy.flat_ir.ops import DynamicGatherOp
22-
from tripy.flat_ir.ops.base import FlatIRFunction
22+
from tripy.flat_ir.function import FlatIRFunction
2323
from tripy.frontend.trace import Trace
2424
import re
2525

tripy/tests/flat_ir/ops/test_maximum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import re
1919
import tripy as tp
2020
from tripy.frontend.trace import Trace
21+
from tripy.flat_ir.function import FlatIRFunction
2122
from tripy.flat_ir.ops import MaxOp
22-
from tripy.flat_ir.ops.base import FlatIRFunction
2323

2424

2525
class TestMaxOp:

tripy/tests/flat_ir/ops/test_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
import tripy as tp
1919
from tripy.frontend.trace import Trace
20+
from tripy.flat_ir.function import FlatIRFunction
2021
from tripy.flat_ir.ops import ArgMinMaxOp, ConvertOp, DivideOp, DynamicBroadcastOp, MulOp, ReduceOp
21-
from tripy.flat_ir.ops.base import FlatIRFunction
2222
import re
2323

2424

tripy/tests/flat_ir/ops/test_subtract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import re
1919
import tripy as tp
2020
from tripy.frontend.trace import Trace
21+
from tripy.flat_ir.function import FlatIRFunction
2122
from tripy.flat_ir.ops import SubtractOp
22-
from tripy.flat_ir.ops.base import FlatIRFunction
2323

2424

2525
class TestSubtractOp:

tripy/tests/flat_ir/test_constant_deduplication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import pytest
1717
from tripy.flat_ir.flat_ir import FlatIR
18-
from tripy.flat_ir.ops.base import FlatIRFunction
18+
from tripy.flat_ir.function import FlatIRFunction
1919
from tripy.flat_ir.ops import ConstantOp
2020
from tripy.flat_ir.tensor import FlatIRTensor
2121
from tripy.common.device import device

tripy/tests/flat_ir/test_function_deduplication.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from typing import List, Optional
1919

2020
from tripy.flat_ir.flat_ir import FlatIR
21-
from tripy.flat_ir.ops.base import FlatIRFunction, BaseFlatIROp
21+
from tripy.flat_ir.function import FlatIRFunction
22+
from tripy.flat_ir.ops.base import BaseFlatIROp
2223
from tripy.flat_ir.ops import ConstantOp
2324
from tripy.flat_ir.tensor import FlatIRTensor
2425
from tripy.common.device import device
@@ -39,7 +40,7 @@ def __eq__(self, other):
3940
return True
4041

4142
def to_mlir(self, operands):
42-
assert "Not implemented"
43+
raise NotImplementedError()
4344

4445

4546
def test_is_structurally_equivalent():

tripy/tripy/backend/mlir/memref.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,6 @@
2525
import mlir_tensorrt.runtime.api as runtime
2626

2727

28-
@lru_cache(maxsize=None)
29-
def _cached_create_memref(shape: Sequence[int], dtype: str, device_kind: str, stream):
30-
mlirtrt_device = mlir_utils.MLIRRuntimeClient().get_devices()[0] if device_kind == "gpu" else None
31-
mlir_dtype = mlir_utils.convert_tripy_dtype_to_runtime_dtype(dtype)
32-
return mlir_utils.MLIRRuntimeClient().create_memref(
33-
shape=list(shape),
34-
dtype=mlir_dtype,
35-
device=mlirtrt_device,
36-
stream=stream,
37-
)
38-
39-
4028
def create_empty_memref(
4129
shape: Sequence[int],
4230
dtype: str,
@@ -54,7 +42,20 @@ def create_empty_memref(
5442
Defaults to True. This ensures we reuse empty memref across functions.
5543
5644
"""
45+
46+
@lru_cache(maxsize=None)
47+
def _cached_create_memref(shape: Sequence[int], dtype: str, device_kind: str, stream):
48+
mlirtrt_device = mlir_utils.MLIRRuntimeClient().get_devices()[0] if device_kind == "gpu" else None
49+
mlir_dtype = mlir_utils.convert_tripy_dtype_to_runtime_dtype(dtype)
50+
return mlir_utils.MLIRRuntimeClient().create_memref(
51+
shape=list(shape),
52+
dtype=mlir_dtype,
53+
device=mlirtrt_device,
54+
stream=stream,
55+
)
56+
5757
if use_cache:
58+
assert common_utils.is_shape_empty(shape)
5859
return _cached_create_memref(tuple(shape), dtype, device.kind, stream)
5960
else:
6061
return _cached_create_memref.__wrapped__(tuple(shape), dtype, device.kind, stream)

tripy/tripy/common/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@ def is_empty(data: Sequence) -> bool:
7474
return isinstance(data, Sequence) and all(map(is_empty, data))
7575

7676

77+
def is_shape_empty(shape: Sequence[int]) -> bool:
78+
"""
79+
A shape is considered empty if any of its dimensions is zero.
80+
81+
Args:
82+
shape (Tuple[int, ...]): A tuple representing the shape of a tensor.
83+
84+
Returns:
85+
bool: True if the shape represents an empty tensor, False otherwise.
86+
"""
87+
return any(dim == 0 for dim in shape)
88+
89+
7790
class Float16MemoryView:
7891
"""
7992
A custom memory view class for handling float16 data.

tripy/tripy/flat_ir/flat_ir.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self, shapes: Sequence[ShapeBounds] = None):
4343

4444
def __str__(self) -> str:
4545
"""Generate a string representation of the FlatIR."""
46-
from tripy.flat_ir.ops.base import FlatIRFunction, BaseFlatIROp
46+
from tripy.flat_ir.function import FlatIRFunction
47+
from tripy.flat_ir.ops.base import BaseFlatIROp
4748

4849
ir = []
4950

@@ -96,7 +97,8 @@ def to_mlir(self):
9697
from mlir_tensorrt.compiler import ir
9798
from mlir_tensorrt.compiler.dialects import func as func_dialect
9899
from tripy.backend.mlir.utils import make_ir_context, make_tensor_location
99-
from tripy.flat_ir.ops.base import FlatIRFunction, BaseFlatIROp
100+
from tripy.flat_ir.function import FlatIRFunction
101+
from tripy.flat_ir.ops.base import BaseFlatIROp
100102

101103
def _base_op_to_mlir(op, mlir_tensor_map):
102104
op_inputs = [mlir_tensor_map[input_tensor.name] for input_tensor in op.inputs]
@@ -376,7 +378,8 @@ def integrate_subgraph(self, inputs: List["FlatIRTensor"], outputs: List["FlatIR
376378
"""
377379
Integrate a subgraph delineated by the given inputs and outputs into this FlatIR.
378380
"""
379-
from tripy.flat_ir.ops.base import BaseFlatIROp, FlatIRFunction
381+
from tripy.flat_ir.function import FlatIRFunction
382+
from tripy.flat_ir.ops.base import BaseFlatIROp
380383
from tripy.flat_ir.tensor import FlatIRTensor
381384
from tripy.flat_ir.ops import ConstantOp
382385

0 commit comments

Comments
 (0)