Skip to content

Commit f676436

Browse files
Makes various optimizations
Applies various optimizations based on data from NSight traces. For example, this includes removing all assertions in the critical path and reducing the amount of Python code involved as much as possible.
1 parent 4e65e58 commit f676436

File tree

13 files changed

+94
-142
lines changed

13 files changed

+94
-142
lines changed

tripy/tests/performance/cases/linear_block.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
)
2929
def linear_block(tripy_dtype, torch_dtype):
3030

31+
NUM_LAYERS = 15
32+
3133
class LinearBlock(tp.Module):
3234
def __init__(self):
33-
self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(10)]
35+
self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(NUM_LAYERS)]
3436
for layer in self.layers:
3537
# Adjust the weights to prevent FP16 overflows:
3638
weight = np.tile(np.array([[-1, 1], [1, -1]], dtype=TRIPY_TO_NUMPY[tripy_dtype]), (128, 128))
@@ -47,7 +49,7 @@ def __init__(self):
4749
self.layers = torch.nn.ModuleList(
4850
[
4951
torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device=torch.device("cuda"))
50-
for _ in range(10)
52+
for _ in range(NUM_LAYERS)
5153
]
5254
)
5355

tripy/tripy/backend/api/executable.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tripy.common.exception import raise_error
2626
from tripy.frontend import Tensor
2727
from tripy.utils import json as json_utils
28+
from tripy.utils.stack_info import StackInfo
2829

2930

3031
@export.public_api(document_under="compiling_code")
@@ -41,6 +42,7 @@ def __init__(self, executable, arg_names, output_devices):
4142
self._executable = executable
4243
self._executor = Executor(self._executable)
4344
self._arg_names = arg_names
45+
self._num_expected_args = len(arg_names)
4446
self._output_devices = output_devices
4547
self._executable_signature = self._executable.get_signature("main")
4648

@@ -88,12 +90,12 @@ def add(a, b):
8890
8991
out = compiled_add(a, b)
9092
"""
91-
NUM_ARGS = len(args) + len(kwargs)
93+
num_positional = len(args)
94+
NUM_ARGS = num_positional + len(kwargs)
9295

93-
input_tensors = []
94-
input_tensors.extend(args)
96+
input_tensors = list(args)
9597
# Need to get arguments in the order of self._arg_names, which may be different from kwargs ordering.
96-
expected_kwargs = self._arg_names[len(args) :]
98+
expected_kwargs = self._arg_names[num_positional:]
9799
for name in expected_kwargs:
98100
if name not in kwargs:
99101
raise_error(f"Missing argument: {name}", [f"Expected the following arguments: {self._arg_names}"])
@@ -106,16 +108,17 @@ def add(a, b):
106108
f"Extra keyword arguments: {list(kwargs.keys())}",
107109
[
108110
f"Expected the following arguments: {self._arg_names}.\n"
109-
f"Note: The following arguments were already provided as positional arguments: {self._arg_names[:len(args)]}"
111+
f"Note: The following arguments were already provided as positional arguments: {self._arg_names[:num_positional]}"
110112
],
111113
)
112114

113115
# We do this after kwarg checks since those will be more informative (we can explain which arguments are missing/extra).
114-
if NUM_ARGS != len(self._arg_names):
116+
117+
if NUM_ARGS != self._num_expected_args:
115118
raise_error(
116119
"Incorrect number of arguments.",
117120
[
118-
f"Expected {len(self._arg_names)} arguments but got {NUM_ARGS}.\n"
121+
f"Expected {self._num_expected_args} arguments but got {NUM_ARGS}.\n"
119122
f"Note: Expected arguments were: {self._arg_names}",
120123
],
121124
)
@@ -158,8 +161,6 @@ def add(a, b):
158161
)
159162
raise
160163

161-
from tripy.utils.stack_info import StackInfo
162-
163164
output_tensors = [Tensor(output, stack_info=StackInfo([])) for output in executor_outputs]
164165
if len(output_tensors) == 1:
165166
output_tensors = output_tensors[0]

tripy/tripy/backend/mlir/executor.py

Lines changed: 32 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,34 @@
1717

1818
from typing import List
1919

20-
import mlir_tensorrt.compiler.api as compiler
2120
import mlir_tensorrt.runtime.api as runtime
2221

22+
from tripy.backend.api.stream import default_stream
2323
from tripy.backend.mlir.memref import create_empty_memref
24+
from tripy.backend.mlir.utils import MLIRRuntimeClient, convert_runtime_dtype_to_tripy_dtype
2425
from tripy.backend.utils import TensorInfo
2526
from tripy.common import datatype, device
2627
from tripy.common.exception import raise_error
28+
from tripy.common.utils import convert_list_to_array
2729
from tripy.utils import make_tuple
2830

2931

3032
class Executor:
3133
def __init__(self, executable: runtime.Executable) -> None:
32-
from tripy.backend.api.stream import default_stream
33-
from tripy.backend.mlir.utils import MLIRRuntimeClient
34-
3534
self.runtime_client = MLIRRuntimeClient()
3635
session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
3736
self.session = runtime.RuntimeSession(session_options, executable)
3837
self.device = self.runtime_client.get_devices()[0] # Assume a single device is available.
3938
self.signature = executable.get_signature("main")
4039
self.stream = default_stream()
40+
self.num_input_args = self.signature.get_num_input_args()
41+
self.num_output_args = self.signature.get_num_output_args()
42+
self.output_args = [
43+
self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
44+
]
45+
self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]
4146

4247
def _create_shape_memref(self, shape):
43-
from tripy.common.utils import convert_list_to_array
44-
4548
shape = make_tuple(shape)
4649
if len(shape) == 0:
4750
# create an empty memref
@@ -55,34 +58,21 @@ def _create_shape_memref(self, shape):
5558
stream=self.stream._active_cuda_stream,
5659
)
5760

58-
def _get_inputs_runtime_shape(self, inputs):
59-
inputs_shape = []
60-
for input in inputs:
61-
inputs_shape.append(input.trace_tensor.producer.data.shape)
62-
return inputs_shape
63-
6461
def _get_outputs_shape(self):
65-
offset = self.signature.get_num_input_args()
6662
outputs_shape = []
6763
all_outputs_known = True
68-
for output_index in range(self.signature.get_num_output_args()):
69-
arg_index = output_index + offset
70-
arg = self.signature.get_arg(arg_index)
71-
assert compiler.MemRefType.isinstance(arg)
72-
memref = runtime.MemRefType(arg)
73-
rank = len(memref.shape)
74-
64+
for memref in self.output_memrefs:
7565
outputs_shape.append(memref.shape)
76-
if rank > 0:
77-
all_outputs_known &= all(dim >= 0 for dim in memref.shape)
66+
all_outputs_known &= all(dim >= 0 for dim in memref.shape)
7867
return outputs_shape, all_outputs_known
7968

80-
def _execute_shape_inference(self, inputs_shape, outputs_shape):
81-
# Only execute shape inference if shape function name is valid.
82-
assert (
83-
self.signature.get_shape_func_name()
84-
), f"Shape inference function is missing while output shapes are not known."
69+
def _get_inputs_runtime_shape(self, inputs):
70+
inputs_shape = []
71+
for input in inputs:
72+
inputs_shape.append(input.trace_tensor.producer.data.shape)
73+
return inputs_shape
8574

75+
def _execute_shape_inference(self, inputs_shape, outputs_shape):
8676
inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
8777
outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
8878
self.session.execute_function(
@@ -93,41 +83,24 @@ def _execute_shape_inference(self, inputs_shape, outputs_shape):
9383
return outputs_runtime_shape
9484

9585
def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
96-
from tripy.backend.mlir.utils import convert_runtime_dtype_to_tripy_dtype
97-
98-
offset = self.signature.get_num_input_args()
9986
outputs_tensor_info = []
100-
for output_index in range(self.signature.get_num_output_args()):
101-
arg_index = output_index + offset
102-
arg = self.signature.get_arg(arg_index)
103-
assert compiler.MemRefType.isinstance(arg) or compiler.ScalarType.isinstance(
104-
arg
105-
), "Argument must be either MemRefType or ScalarType"
106-
assert compiler.MemRefType.isinstance(
107-
arg
108-
), "ScalarType argument are not yet supported" # 158: Add scalar type output argument support.
109-
memref = compiler.MemRefType(arg)
87+
for index in range(self.num_output_args):
88+
memref = self.output_memrefs[index]
11089
dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)
111-
device_type = "gpu" if memref.address_space == runtime.PointerType.device else "cpu"
112-
if output_devices[output_index]:
113-
device_type = output_devices[output_index].kind
114-
is_static_shape = all(dim >= 0 for dim in memref.shape)
115-
if is_static_shape:
116-
outputs_tensor_info.append(
117-
TensorInfo(len(memref.shape), tuple(memref.shape), dtype, device(device_type))
118-
)
119-
else:
120-
runtime_shape = [
121-
rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[output_index])
122-
]
123-
outputs_tensor_info.append(
124-
TensorInfo(
125-
len(runtime_shape),
126-
tuple(runtime_shape),
127-
dtype,
128-
device(device_type),
129-
)
90+
91+
output_device = output_devices[index]
92+
if not output_device:
93+
output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0))
94+
95+
runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
96+
outputs_tensor_info.append(
97+
TensorInfo(
98+
len(runtime_shape),
99+
tuple(runtime_shape),
100+
dtype,
101+
output_device,
130102
)
103+
)
131104
return outputs_tensor_info
132105

133106
def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):

tripy/tripy/backend/mlir/memref.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _cached_create_empty_memref(shape: Sequence[int], dtype: str, device_kind: s
4141
def create_empty_memref(
4242
shape: Sequence[int],
4343
dtype: str,
44-
device: tp_device = tp_device("gpu"),
44+
device: tp_device = tp_device(("gpu", 0)),
4545
stream=None,
4646
use_cache: bool = True,
4747
):

tripy/tripy/backend/mlir/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,17 @@ def redirect_stderr() -> BinaryIO:
288288

289289

290290
def convert_tripy_dtype_to_runtime_dtype(dtype: datatype.dtype) -> runtime.ScalarTypeCode:
291-
if dtype not in TRIPY_DTYPE_TO_MLIR_TRT:
291+
try:
292+
return TRIPY_DTYPE_TO_MLIR_TRT[dtype]
293+
except KeyError:
292294
raise_error(f"Data type: '{dtype}' does not have a corresponding runtime data type")
293-
return TRIPY_DTYPE_TO_MLIR_TRT.get(dtype)
294295

295296

296297
def convert_runtime_dtype_to_tripy_dtype(dtype: runtime.ScalarTypeCode) -> datatype.dtype:
297-
if dtype not in MLIR_TRT_TO_TRIPY_DTYPE:
298+
try:
299+
return MLIR_TRT_TO_TRIPY_DTYPE[dtype]
300+
except KeyError:
298301
raise_error(f"Data type: '{dtype}' does not have a corresponding tripy data type")
299-
return MLIR_TRT_TO_TRIPY_DTYPE.get(dtype)
300302

301303

302304
def is_any_dim_dynamic(mlir_tensor):

tripy/tripy/common/device.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from tripy.common.exception import TripyException
2323
from tripy.utils.json import Decoder, Encoder
2424

25+
_VALID_KINDS = {"cpu", "gpu"}
26+
2527

2628
@export.public_api()
2729
@dataclass
@@ -60,27 +62,30 @@ def __init__(self, device: str) -> None:
6062
assert gpu_1.kind == "gpu"
6163
assert gpu_1.index == 1
6264
"""
63-
64-
kind, _, index = device.partition(":")
65-
kind = kind.lower()
66-
67-
if index:
68-
try:
69-
index = int(index)
70-
except ValueError:
71-
raise TripyException(f"Could not interpret: {index} as an integer")
72-
else:
73-
index = 0
74-
75-
if index < 0:
76-
raise TripyException(f"Device index must be a non-negative integer, but was: {index}")
77-
78-
VALID_KINDS = {"cpu", "gpu"}
79-
if kind not in VALID_KINDS:
80-
raise TripyException(f"Unrecognized device kind: {kind}. Choose from: {list(VALID_KINDS)}")
81-
82-
self.kind = kind
83-
self.index = index
65+
try:
66+
# Fast constructor for the critical path. If a Tuple[str, int] is provided, then
67+
# we bypass all the logic to parse the information from a string.
68+
self.kind, self.index = device
69+
except ValueError:
70+
kind, _, index = device.partition(":")
71+
kind = kind.lower()
72+
73+
if index:
74+
try:
75+
index = int(index)
76+
except ValueError:
77+
raise TripyException(f"Could not interpret: {index} as an integer")
78+
else:
79+
index = 0
80+
81+
if index < 0:
82+
raise TripyException(f"Device index must be a non-negative integer, but was: {index}")
83+
84+
if kind not in _VALID_KINDS:
85+
raise TripyException(f"Unrecognized device kind: {kind}. Choose from: {list(_VALID_KINDS)}")
86+
87+
self.kind = kind
88+
self.index = index
8489

8590
def __str__(self) -> str:
8691
return f"{self.kind}:{self.index}"

tripy/tripy/frontend/tensor.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from textwrap import indent
1919
from typing import Any, List, Optional, Sequence, Union
2020

21+
import mlir_tensorrt.runtime.api as runtime
22+
2123
# Import ops to populate the registry before we define our Tensor class
2224
import tripy.frontend.ops
2325
import tripy.frontend.trace.ops
@@ -27,8 +29,7 @@
2729
from tripy.common.exception import raise_error
2830
from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY
2931
from tripy.frontend.trace.ops import Storage
30-
31-
import mlir_tensorrt.runtime.api as runtime
32+
from tripy.frontend.trace.tensor import TraceTensor
3233

3334

3435
class TensorMeta(type):
@@ -93,7 +94,6 @@ def __init__(
9394
9495
tensor = tp.Tensor([1.0, 2.0, 3.0], dtype=tp.float32)
9596
"""
96-
from tripy.frontend.trace.tensor import TraceTensor
9797

9898
# We include code for everything above the `BaseTraceOp.build` function, which is called at most
9999
# this many stack frames above the constructor.
@@ -120,17 +120,6 @@ def __init__(
120120
else:
121121
Storage.build_internal([], [self.trace_tensor], data, dtype, device)
122122

123-
# Storage should populate attrs of trace_tensor
124-
assert all(
125-
attr is not None
126-
for attr in [
127-
self.trace_tensor.shape,
128-
self.trace_tensor.dtype,
129-
self.trace_tensor.device,
130-
self.trace_tensor.producer,
131-
]
132-
)
133-
134123
# Explicit cast if necessary
135124
# TODO(#155): Add copy as well when host allocation is fixed
136125
# Also make device as a property, similar to dtype
@@ -172,13 +161,13 @@ def rank(self):
172161
return self.trace_tensor.rank
173162

174163
def eval(self) -> runtime.MemRefValue:
164+
if isinstance(self.trace_tensor.producer, Storage) and self.trace_tensor.producer.has_memref:
165+
return self.trace_tensor.producer.data
166+
175167
from tripy.backend.mlir.compiler import Compiler
176168
from tripy.backend.mlir.executor import Executor
177169
from tripy.frontend.trace import Trace
178170

179-
if isinstance(self.trace_tensor.producer, Storage) and self.trace_tensor.producer.has_memref:
180-
return self.trace_tensor.producer.data
181-
182171
trace = Trace([self])
183172
flat_ir = trace.to_flat_ir()
184173
mlir = flat_ir.to_mlir()

0 commit comments

Comments
 (0)