Skip to content

Commit f24352b

Browse files
committed
Add function deduplication
1 parent ecd0a28 commit f24352b

File tree

13 files changed

+647
-274
lines changed

13 files changed

+647
-274
lines changed

tripy/tests/flat_ir/ops/test_gather.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_gather_mlir(self, axis):
5555
flat_ir = trace.to_flat_ir()
5656
mlir_text = str(flat_ir.to_mlir())
5757
if axis == 0:
58-
target = '"stablehlo.dynamic_gather"(%arg0, %arg1, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>}> : (tensor<2x3xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<1x?xi32>'
58+
target = '"stablehlo.dynamic_gather"(%arg0, %arg1, %6) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>}> : (tensor<?x?xi32>, tensor<?xi32>, tensor<2xi32>) -> tensor<?x?xi32>'
5959
else:
60-
target = '"stablehlo.dynamic_gather"(%arg0, %arg1, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<2x3xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<?x1xi32>'
60+
target = '"stablehlo.dynamic_gather"(%arg0, %arg1, %6) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<?x?xi32>, tensor<?xi32>, tensor<2xi32>) -> tensor<?x?xi32>'
6161
assert target in mlir_text, mlir_text

tripy/tests/flat_ir/ops/test_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_str(self, flat_ir):
5959

6060
def test_mlir(self, flat_ir):
6161
assert """
62-
tensorrt.opaque_plugin {creator_params = {output_height = 5 : i32, output_width = 5 : i32}, plugin_name = "ROIAlign_TRT", plugin_namespace = "", plugin_version = "1"}(%0, %cst, %1) : (tensor<?x?x?x?xf32>, tensor<2x4xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
62+
tensorrt.opaque_plugin {creator_params = {output_height = 5 : i32, output_width = 5 : i32}, plugin_name = "ROIAlign_TRT", plugin_namespace = "", plugin_version = "1"}(%0, %cst, %2) : (tensor<?x?x?x?xf32>, tensor<2x4xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
6363
""".strip() in str(
6464
flat_ir.to_mlir()
6565
)

tripy/tests/flat_ir/test_constant_deduplication.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,10 @@ def test_integrate_subgraph_constant_deduplication(config):
9292
mock_op = [op for op in ops if isinstance(op, MockOp)][0]
9393
assert mock_op.inputs[0] is mock_op.inputs[1], "The mock op should use the same tensor for its first two inputs"
9494
assert mock_op.inputs[0] is not mock_op.inputs[2], "The mock op should still have a different third input"
95+
96+
if config == "main":
97+
# Verify that tensor replacements were applied
98+
assert len(flat_ir.tensor_replacements) > 0, "There should be tensor replacements after integration"
99+
100+
# Verify that the constant map has the correct number of entries
101+
assert len(flat_ir.constant_map) == 2, "Constant map should have 2 entries"
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import pytest
16+
import re
17+
from dataclasses import dataclass
18+
from typing import List, Optional
19+
20+
from tripy.flat_ir.flat_ir import FlatIR
21+
from tripy.flat_ir.ops.base import FlatIRFunction, BaseFlatIROp
22+
from tripy.flat_ir.ops import ConstantOp
23+
from tripy.flat_ir.tensor import FlatIRTensor
24+
from tripy.common.device import device
25+
from tripy.common.datatype import float32, int32
26+
27+
28+
@dataclass(repr=False, eq=False)
29+
class MockOp(BaseFlatIROp):
30+
def __init__(self, inputs, outputs):
31+
self.inputs = inputs
32+
self.outputs = outputs
33+
self.trace_input_names = []
34+
self.trace_output_names = []
35+
for output in outputs:
36+
output.producer = self
37+
38+
def __eq__(self, other):
39+
return True
40+
41+
def to_mlir(self, operands):
42+
assert "Not implemented"
43+
44+
45+
def test_is_structurally_equivalent():
46+
"""Test the structural equivalence of two FlatIR functions."""
47+
flat_ir = FlatIR()
48+
49+
def create_tensor(reason_details: str, name: Optional[str] = None) -> FlatIRTensor:
50+
"""Create and register a FlatIRTensor."""
51+
t = FlatIRTensor.build(
52+
shape=[3],
53+
rank=1,
54+
dtype=float32,
55+
device=device("gpu"),
56+
reason_details=reason_details,
57+
)
58+
if name:
59+
t.name = name
60+
flat_ir.register_tensor(t)
61+
return t
62+
63+
def create_function(
64+
name: str,
65+
input_tensor: FlatIRTensor,
66+
output_tensors: List[FlatIRTensor],
67+
) -> FlatIRFunction:
68+
"""Create a FlatIRFunction with associated operations."""
69+
callee_input = input_tensor.clone(reason_details=f"{name} input cloned from {input_tensor}")
70+
callee_outputs = [out.clone(reason_details=f"{name} output cloned from {out}") for out in output_tensors]
71+
72+
flat_ir.register_tensor(callee_input)
73+
setattr(callee_input, "caller_tensor", input_tensor)
74+
75+
for callee_out, original_out in zip(callee_outputs, output_tensors):
76+
flat_ir.register_tensor(callee_out)
77+
setattr(callee_out, "caller_tensor", original_out)
78+
79+
func = FlatIRFunction(name, [callee_input], callee_outputs)
80+
mock_op = MockOp([callee_input], [callee_outputs[0]])
81+
const_op = ConstantOp.build([], [callee_outputs[1]], data=[3, 4, 5])
82+
callee_outputs[1].producer = const_op
83+
84+
func.ops.extend([mock_op, const_op])
85+
for out in output_tensors:
86+
out.producer = func
87+
88+
return func
89+
90+
# Create main tensors
91+
input_tensor = create_tensor("Function 1 input", "main_input_tensor")
92+
intermediates = [create_tensor(f"Function 1 output {i}", f"intermediate_tensor_{i}") for i in range(2)]
93+
outputs = [create_tensor(f"Function 2 output {i}", f"main_output_tensor_{i}") for i in range(2)]
94+
95+
# Create two structurally equivalent functions
96+
func_1 = create_function("Func1", input_tensor, intermediates)
97+
func_2 = create_function("Func2", intermediates[0], outputs)
98+
99+
# Assert structural equivalence
100+
assert func_1.is_structurally_equivalent(func_2)
101+
102+
# Set up FlatIR inputs and outputs
103+
flat_ir.inputs = [input_tensor]
104+
flat_ir.outputs = outputs
105+
106+
# Integrate subgraphs
107+
for in_tensor, out_tensors in [(input_tensor, intermediates), (intermediates[0], outputs)]:
108+
flat_ir.integrate_subgraph([in_tensor], out_tensors)
109+
110+
flat_ir_str = str(flat_ir)
111+
112+
# Check Func1 structure
113+
func_pattern = re.compile(r"function\s+Func1\s*\(\s*\w+:.*?\)\s*->\s*\(.*?\)\s*{.*?return.*?}", re.DOTALL)
114+
assert func_pattern.search(flat_ir_str), "Function Func1 structure is incorrect"
115+
116+
# Check Main Function structure
117+
main_pattern = re.compile(
118+
r"Main Function:.*?inputs:.*?=\s*function Func1.*?=\s*function Func1.*?outputs:", re.DOTALL
119+
)
120+
assert main_pattern.search(flat_ir_str), "Main Function structure is incorrect"
121+
122+
print("All assertions passed. Function structures are correct.")

tripy/tests/helper.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ def raises(ExcType: type, match: Optional[str] = None, has_stack_info_for: Seque
7676
for tensor in has_stack_info_for:
7777
# Stack info is indented since it's part of the `details` block in `raise_error`
7878
expected_stack_info = indent(_make_stack_info_message(tensor.stack_info).strip(), " " * 4)
79-
# TODO: How to add stack information for broadcasted tensors.
80-
# assert expected_stack_info in error_msg, f"Missing stack information for tensor:\n{expected_stack_info}"
79+
assert expected_stack_info in error_msg, f"Missing stack information for tensor:\n{expected_stack_info}"
8180

8281

8382
@contextlib.contextmanager

tripy/tripy/backend/mlir/executor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,11 @@ def execute(self, output_devices=List[device], inputs: List["Tensor"] = []) -> L
174174
# Allocate output memory and store buffer pointers.
175175
outputs = [
176176
create_empty_memref(
177-
shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
177+
shape=info.shape,
178+
dtype=info.dtype,
179+
device=info.device,
180+
stream=self.stream._active_cuda_stream,
181+
use_cache=False,
178182
)
179183
for info in out_tensor_info
180184
]

tripy/tripy/backend/mlir/memref.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616
#
1717

18+
from functools import lru_cache
19+
from typing import Sequence
1820

1921
from tripy.backend.mlir import utils as mlir_utils
2022
from tripy.common import device as tp_device
@@ -23,11 +25,9 @@
2325
import mlir_tensorrt.runtime.api as runtime
2426

2527

26-
def create_empty_memref(shape, dtype, device=tp_device("gpu"), stream=None):
27-
"""
28-
Creates an empty memref, used for allocating memory.
29-
"""
30-
mlirtrt_device = mlir_utils.MLIRRuntimeClient().get_devices()[0] if device == tp_device("gpu") else None
28+
@lru_cache(maxsize=None)
29+
def _cached_create_memref(shape: Sequence[int], dtype: str, device_kind: str, stream):
30+
mlirtrt_device = mlir_utils.MLIRRuntimeClient().get_devices()[0] if device_kind == "gpu" else None
3131
mlir_dtype = mlir_utils.convert_tripy_dtype_to_runtime_dtype(dtype)
3232
return mlir_utils.MLIRRuntimeClient().create_memref(
3333
shape=list(shape),
@@ -37,11 +37,33 @@ def create_empty_memref(shape, dtype, device=tp_device("gpu"), stream=None):
3737
)
3838

3939

40+
def create_empty_memref(
41+
shape: Sequence[int],
42+
dtype: str,
43+
device: tp_device = tp_device("gpu"),
44+
stream=None,
45+
use_cache: bool = True,
46+
):
47+
"""
48+
Creates an empty memref, used for allocating memory.
49+
Caches the result for subsequent calls with the same parameters.
50+
51+
Args:
52+
use_cache (bool, optional): Whether to use cached results for repeated calls with the same parameters.
53+
If True, returns cached results if available. If False, always creates a new memref.
54+
Defaults to True. This ensures we reuse empty memref across functions.
55+
56+
"""
57+
if use_cache:
58+
return _cached_create_memref(tuple(shape), dtype, device.kind, stream)
59+
else:
60+
return _cached_create_memref.__wrapped__(tuple(shape), dtype, device.kind, stream)
61+
62+
4063
def create_memref_view(data):
4164
"""
4265
Creates a memref view of an array object that implements the dlpack interface.
4366
"""
44-
4567
return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
4668

4769

tripy/tripy/backend/mlir/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ def parse_tensor_names_from_location(msg: str) -> Tuple[List[str], List[str], Li
228228
if not loc:
229229
return [], [], [], [], msg
230230

231+
# Hack: Extract callsite for function call locations.
232+
if "at" in loc:
233+
_, _, loc = loc.partition('at "')
231234
input_names, _, loc = loc.partition(OUTPUT_SEPARATOR)
232235
output_names, _, loc = loc.partition(TRACE_INPUTS_SEPARATOR)
233236
trace_inputs, _, trace_outputs = loc.partition(TRACE_OUTPUTS_SEPARATOR)
@@ -306,6 +309,33 @@ def is_any_dim_dynamic(mlir_tensor):
306309
return any([type.is_dynamic_dim(i) for i in range(type.rank)])
307310

308311

312+
def has_all_dynamic_dims(tensor_type: ir.RankedTensorType) -> bool:
313+
"""Check if all dimensions of a tensor type are dynamic."""
314+
if not isinstance(tensor_type, ir.RankedTensorType):
315+
raise ValueError("Input must be a RankedTensorType")
316+
317+
return all(dim == ir.ShapedType.get_dynamic_size() for dim in tensor_type.shape)
318+
319+
320+
def cast_to_dynamic_ranked_tensor(input_tensor: ir.Value, always_insert_cast: bool = False) -> ir.Value:
321+
"""Cast a tensor to a dynamic ranked tensor if necessary."""
322+
from mlir_tensorrt.compiler.dialects._ods_common import get_op_result_or_value
323+
from mlir_tensorrt.compiler.dialects import stablehlo
324+
325+
input_type = get_op_result_or_value(input_tensor).type
326+
327+
if not ir.RankedTensorType.isinstance(input_type):
328+
raise ValueError("Input must be a RankedTensorType")
329+
330+
if not always_insert_cast and has_all_dynamic_dims(input_type):
331+
return input_tensor
332+
333+
dynamic_shape = [ir.ShapedType.get_dynamic_size()] * input_type.rank
334+
dynamic_type = ir.RankedTensorType.get(dynamic_shape, input_type.element_type)
335+
336+
return stablehlo.ConvertOp(result=dynamic_type, operand=input_tensor).result
337+
338+
309339
class ShapeContext:
310340
_instance = None
311341

0 commit comments

Comments
 (0)