diff --git a/.github/workflows/tripy-l0.yml b/.github/workflows/tripy-l0.yml index d8fa8b38c..dd8b9218c 100644 --- a/.github/workflows/tripy-l0.yml +++ b/.github/workflows/tripy-l0.yml @@ -42,7 +42,7 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: build-new-container + - name: Build new container if: steps.filter.outputs.local_container == 'true' uses: docker/build-push-action@v6 with: @@ -50,7 +50,7 @@ jobs: tags: ${{ env.NEW_TEST_IMAGE }} push: false - - name: pull-latest-container + - name: Pull latest container if: steps.filter.outputs.local_container != 'true' run: docker pull ${{ env.l0_image }} @@ -63,10 +63,31 @@ jobs: python3 docs/generate_rsts.py sphinx-build build/doc_sources build/docs -c docs/ -j 4 -W -n - - name: run-test + - name: Run tests uses: addnab/docker-run-action@v3 with: image: ${{ env.l0_image }} options: --gpus all -v ${{ github.workspace }}/tripy:/tripy run: | - pytest --cov=tripy/ --cov-config=.coveragerc tests/ -v -m "not l1 and not manual" -n 4 --durations=15 + pytest --cov=tripy/ --cov-config=.coveragerc tests/ -v -m "not l1 and not manual" -n 4 --durations=15 --ignore tests/performance + + - name: Run performance benchmarks + uses: addnab/docker-run-action@v3 + with: + image: ${{ env.l0_image }} + options: --gpus all -v ${{ github.workspace }}/tripy:/tripy + run: | + pytest tests/performance -v -m "not l1 and not manual" --benchmark-warmup=on --benchmark-json benchmark.json + + - name: Store benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + tool: 'pytest' + output-file-path: ${{ github.workspace }}/tripy/benchmark.json + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: true + # Show alert with commit comment on detecting possible performance regression + alert-threshold: '105%' + comment-on-alert: true + fail-on-alert: true + gh-pages-branch: benchmarks diff --git a/.github/workflows/tripy-l1.yml b/.github/workflows/tripy-l1.yml index 84c87ce05..8d18a84c7 100644 --- a/.github/workflows/tripy-l1.yml +++ b/.github/workflows/tripy-l1.yml @@ -53,4 +53,4 @@ jobs: - name: l1-test run: | cd /tripy/ - pytest --cov=tripy/ --cov-config=.coveragerc tests/ -v -m "not manual" -n 4 --durations=15 + pytest --cov=tripy/ --cov-config=.coveragerc tests/ -v -m "l1 and not manual" -n 4 --durations=15 --ignore tests/performance diff --git a/tripy/pyproject.toml b/tripy/pyproject.toml index e8ec89784..5f247795e 100644 --- a/tripy/pyproject.toml +++ b/tripy/pyproject.toml @@ -55,6 +55,10 @@ test = [ "pytest-profiling==1.7.0", "pytest-cov==4.1.0", "pytest-xdist==3.6.1", + "pytest-benchmark==4.0.0", + "pytest-lazy-fixture==0.6.3", + # Triton is required for torch.compile + "triton==3.0.0", "snakeviz==2.2.0", "coverage==7.4.1", "vulture==2.11", diff --git a/tripy/tests/README.md b/tripy/tests/README.md index 5ae218c9d..48e16f895 100644 --- a/tripy/tests/README.md +++ b/tripy/tests/README.md @@ -10,15 +10,20 @@ The `tests/integration` directory captures the latter group of tests. You can run all tests locally in the development container by running: ```bash -pytest tests/ -v +pytest tests/ -v -n 4 --dist worksteal --ignore tests/performance +pytest tests/performance -v ``` +Performance tests are run separately because they must run serially to ensure +accurate measurements. + You can also provide marker arguments to only run specific test cadences (see [the test cadence section](#test-cadence) below). 
For example, to run only L0 tests, use: ```bash -pytest tests/ -v -m "not l1 and not manual" -n 4 +pytest tests/ -v -m "not l1 and not manual" -n 4 --dist worksteal --ignore tests/performance +pytest tests/performance -v -m "not l1 and not manual" ``` @@ -56,7 +61,7 @@ http://localhost:8080/snakeviz/%2Ftripy%2Fprof%2Fcombined.prof You can generate code coverage reports locally by running: ```bash -pytest --cov=tripy/ --cov-report=html --cov-config=.coveragerc tests/ -n 4 -v +pytest --cov=tripy/ --cov-report=html --cov-config=.coveragerc tests/ -v ``` To view the report, open the `htmlcov/index.html` file from the root directory in a browser. @@ -125,3 +130,26 @@ Any caption other than `Example` will have a prefix of `Example: ` prepended to **NOTE: The docstrings must *not* import `tripy`, `numpy`, or `torch`. They will be imported** **automatically as `tp`, `np`, and `torch` respectively. Any other modules will need to be imported.** + + +### Performance Tests + +In addition to functional tests, we also run performance tests of three kinds: + +1. Regression tests, which compare current Tripy performance to historical data + to ensure we don't regress. We use the + [`pytest-benchmark`](https://pytest-benchmark.readthedocs.io/en/latest/) + plugin to gather data and the + [Continuous Benchmark GitHub Action](https://github.com/marketplace/actions/continuous-benchmark) + for regression testing. + + You can view graphs and charts of the historical data by opening the + [`index.html` file from the `benchmarks` branch](https://github.com/NVIDIA/TensorRT-Incubator/blob/benchmarks/dev/bench/index.html) + in a browser. + +2. Comparative tests, which compare Tripy and `torch.compile`. + +3. Overhead tests, which check the overhead introduced by Tripy as compared + to running the underlying MLIR executable by itself. This is done by measuring + how long it takes to run an empty executable since in that case, all the time + is taken by the Tripy wrapper code. 
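+
+To reproduce the regression benchmarks locally, you can use an invocation along
+the lines of what CI runs (see `.github/workflows/tripy-l0.yml` for the
+authoritative flags):
+
+```bash
+pytest tests/performance -v -m "not l1 and not manual" --benchmark-warmup=on --benchmark-json benchmark.json
+```
+
+The resulting `benchmark.json` file is what the Continuous Benchmark action
+consumes for regression detection.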
diff --git a/tripy/tests/backend/api/test_executable.py b/tripy/tests/backend/api/test_executable.py index f94a6a2e9..3b588b463 100644 --- a/tripy/tests/backend/api/test_executable.py +++ b/tripy/tests/backend/api/test_executable.py @@ -66,12 +66,12 @@ def test_kwargs(self, single_return_executable): ( [tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)], {"b": tp.ones((2, 2), dtype=tp.float32)}, - "Extra keyword arguments: \['b'\]", + r"Extra keyword arguments: \['b'\]", ), ( [tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)], {"c": tp.ones((2, 2), dtype=tp.float32)}, - "Extra keyword arguments: \['c'\]", + r"Extra keyword arguments: \['c'\]", ), ], ) diff --git a/tripy/tests/frontend/module/test_module.py b/tripy/tests/frontend/module/test_module.py index fdef2d139..ef1022918 100644 --- a/tripy/tests/frontend/module/test_module.py +++ b/tripy/tests/frontend/module/test_module.py @@ -46,7 +46,7 @@ def test_get_set_attr(self, network): def test_incompatible_parameter_cannot_be_set(self, network): with helper.raises( - tp.TripyException, match="New parameter shape: \[2, 3\] is not compatible with current shape: \[2\]" + tp.TripyException, match=r"New parameter shape: \[2, 3\] is not compatible with current shape: \[2\]" ): network.param = tp.Parameter(tp.ones((2, 3))) diff --git a/tripy/tests/frontend/test_shape.py b/tripy/tests/frontend/test_shape.py index e27504b05..298c73484 100644 --- a/tripy/tests/frontend/test_shape.py +++ b/tripy/tests/frontend/test_shape.py @@ -50,13 +50,19 @@ class TestShapeScalar: np.array(2, dtype=np.int32), ], ) - def test_scalar_shape(self, value): + def test_construction(self, value): s = tp.ShapeScalar(value) assert isinstance(s, tp.ShapeScalar) assert s.trace_tensor.producer.inputs == [] - def test_scalar_shape_str_method(self): + def test_int_conversion(self): + val = 4 + s = tp.ShapeScalar(val) + + assert int(s) == val + + def test_str_method(self): s = tp.ShapeScalar(12) assert s.__str__() == f"shape_scalar(12)" diff --git a/tripy/tests/frontend/test_tensor.py b/tripy/tests/frontend/test_tensor.py index a27f2b1f1..8da8c9187 100644 --- a/tripy/tests/frontend/test_tensor.py +++ b/tripy/tests/frontend/test_tensor.py @@ -22,12 +22,12 @@ import numpy as np import pytest import torch +from tests.conftest import DATA_TYPE_TEST_CASES +from tests.helper import NUMPY_TO_TRIPY import tripy as tp -from tests.conftest import DATA_TYPE_TEST_CASES -from tests.helper import NUMPY_TYPES, np_to_tripy_dtype -from tripy.utils.stack_info import SourceInfo from tripy.common.utils import get_element_type +from tripy.utils.stack_info import SourceInfo class TestTensor: @@ -52,12 +52,12 @@ def test_tensor_device(self, kind): assert isinstance(a.trace_tensor.producer, tp.frontend.trace.ops.Storage) assert a.trace_tensor.producer.device.kind == kind - @pytest.mark.parametrize("dtype", NUMPY_TYPES) + @pytest.mark.parametrize("dtype", NUMPY_TO_TRIPY.keys()) def test_dtype_from_numpy(self, dtype): np_array = np.array([1, 2, 3], dtype=dtype) tensor = tp.Tensor(np_array) - tp_dtype = np_to_tripy_dtype(dtype) + tp_dtype = NUMPY_TO_TRIPY[dtype] assert tensor.dtype == tp_dtype def test_bool_tensor(self): diff --git a/tripy/tests/frontend/trace/ops/test_reshape.py b/tripy/tests/frontend/trace/ops/test_reshape.py index 5e960122e..0a9610a35 100644 --- a/tripy/tests/frontend/trace/ops/test_reshape.py +++ b/tripy/tests/frontend/trace/ops/test_reshape.py @@ -55,7 +55,7 @@ def test_incorrect_dims(self): with helper.raises( tp.TripyException, - match="number 
of output elements \(1\) doesn't match expected number of elements \(4\)", + match=r"number of output elements \(1\) doesn't match expected number of elements \(4\)", has_stack_info_for=[a, b], ): b.eval() diff --git a/tripy/tests/helper.py b/tripy/tests/helper.py index 274daf571..4c44e8e2f 100644 --- a/tripy/tests/helper.py +++ b/tripy/tests/helper.py @@ -102,35 +102,22 @@ def check_mlir(mlir, expected): # Supported NumPy data types -NUMPY_TYPES = [ - np.int8, +NUMPY_TO_TRIPY = { + bool: tp.bool, + np.int8: tp.int8, + np.int32: tp.int32, + np.int64: tp.int64, + np.float16: tp.float16, + np.float32: tp.float32, # np.int16, # TODO(#247): Add support for int16 - np.int32, - np.int64, # np.uint8, # TODO(#247): Add support for uint8 # np.uint16, # TODO(#190): Add support for unsupported MLIR-TensorRT types. # np.uint32, # TODO(#190): Add support for unsupported MLIR-TensorRT types. # np.uint64, # TODO(#190): Add support for unsupported MLIR-TensorRT types. - np.float16, - np.float32, # np.float64, # TODO(#247): Add support for float64 -] - - -def np_to_tripy_dtype(dtype): - return { - bool: tp.bool, - np.int8: tp.int8, - np.int32: tp.int32, - np.int64: tp.int64, - np.float16: tp.float16, - np.float32: tp.float32, - }[dtype] - +} -def torch_type_supported(data: np.ndarray): - unsupported_dtypes = [np.int16, np.uint16, np.uint32, np.uint64] - return data.dtype not in unsupported_dtypes +TRIPY_TO_NUMPY = {v: k for k, v in NUMPY_TO_TRIPY.items()} TORCH_DTYPES = { diff --git a/tripy/tests/integration/test_cast.py b/tripy/tests/integration/test_cast.py index f87536b77..3e5902924 100644 --- a/tripy/tests/integration/test_cast.py +++ b/tripy/tests/integration/test_cast.py @@ -21,7 +21,7 @@ import tripy as tp from tests.conftest import skip_if_older_than_sm89 -from tests.helper import np_to_tripy_dtype +from tests.helper import NUMPY_TO_TRIPY class TestCast: @@ -50,8 +50,8 @@ class TestCast: ], ) def test_cast(self, input_dtype, target_dtype): - tp_input_dtype = np_to_tripy_dtype(input_dtype) - tp_target_dtype = np_to_tripy_dtype(target_dtype) + tp_input_dtype = NUMPY_TO_TRIPY[input_dtype] + tp_target_dtype = NUMPY_TO_TRIPY[target_dtype] # TODO(#222): Integer casts with negative numbers fail in many cases input_tensor = tp.Tensor([0, 1, 2], dtype=tp_input_dtype) @@ -71,7 +71,7 @@ def test_cast_quantized_dtypes_into_bool(self, source_dtype): @pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int64, np.int8]) def test_cast_from_bool(self, target_dtype): - tp_target_dtype = np_to_tripy_dtype(target_dtype) + tp_target_dtype = NUMPY_TO_TRIPY[target_dtype] # in principle, it is not important what *specific* values we convert to, # so long as false is mapped to 0 and true to nonzero diff --git a/tripy/tests/integration/test_stack.py b/tripy/tests/integration/test_stack.py index 256dab628..cdecc4ae3 100644 --- a/tripy/tests/integration/test_stack.py +++ b/tripy/tests/integration/test_stack.py @@ -57,6 +57,6 @@ def test_stack_different_shapes(self): b = tp.ones((4, 3)) with raises( tp.TripyException, - match="error: shapes of operand \(0\) and \(1\) are not compatible at non-concat index 1:", + match=r"error: shapes of operand \(0\) and \(1\) are not compatible at non-concat index 1:", ): tp.stack([a, b]).eval() diff --git a/tripy/tests/performance/__init__.py b/tripy/tests/performance/__init__.py new file mode 100644 index 000000000..a08b2c204 --- /dev/null +++ b/tripy/tests/performance/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tripy/tests/performance/cases/__init__.py b/tripy/tests/performance/cases/__init__.py new file mode 100644 index 000000000..d1b6ef002 --- /dev/null +++ b/tripy/tests/performance/cases/__init__.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +__all__ = [] + + +# In order to make the pytest fixtures defined in this submodule visible, we +# need to import them in the test using their function names. To do so, we can +# export them via this file by making them local variables and adding them to `__all__`. +# +# Note that just importing the module is sufficient to update PERF_CASES, but does +# not make the actual fixture function visible to pytest. +def __discover_modules(): + import importlib + import pkgutil + + mods = [importlib.import_module("tests.performance.cases")] + while mods: + mod = mods.pop(0) + + yield mod + + if hasattr(mod, "__path__"): + mods.extend( + [ + importlib.import_module(f"{mod.__name__}.{submod.name}") + for submod in pkgutil.iter_modules(mod.__path__) + ] + ) + + +modules = list(__discover_modules())[1:] + +# Discover and import all perf fixtures. +from tests.performance.conftest import PERF_CASES + +__perf_case_names = {case.name for case in PERF_CASES} + +for mod in modules: + for name, obj in mod.__dict__.items(): + if name in __perf_case_names: + locals()[name] = obj + __all__.append(name) diff --git a/tripy/tests/performance/cases/linear_block.py b/tripy/tests/performance/cases/linear_block.py new file mode 100644 index 000000000..1cd74746f --- /dev/null +++ b/tripy/tests/performance/cases/linear_block.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch +from tests.helper import TRIPY_TO_NUMPY +from tests.performance.conftest import PerfParam, perf_fixture + +import tripy as tp + + +@perf_fixture( + params=[ + PerfParam(tp.float32, 1.25), + PerfParam(tp.float16), + ] +) +def linear_block(tripy_dtype, torch_dtype): + + NUM_LAYERS = 15 + + class LinearBlock(tp.Module): + def __init__(self): + self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(NUM_LAYERS)] + for layer in self.layers: + # Adjust the weights to prevent FP16 overflows: + weight = np.tile(np.array([[-1, 1], [1, -1]], dtype=TRIPY_TO_NUMPY[tripy_dtype]), (128, 128)) + layer.weight = tp.Parameter(weight) + + def __call__(self, input): + for layer in self.layers: + input = layer(input) + return input + + class TorchLinearBlock(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList( + [ + torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device=torch.device("cuda")) + for _ in range(NUM_LAYERS) + ] + ) + + def forward(self, input): + for layer in self.layers: + input = layer(input) + return input + + tripy_block = LinearBlock() + torch_block = TorchLinearBlock() + input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=tripy_dtype)} + return tripy_block, torch_block, input_infos diff --git a/tripy/tests/performance/conftest.py b/tripy/tests/performance/conftest.py new file mode 100644 index 000000000..b405f733b --- /dev/null +++ b/tripy/tests/performance/conftest.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List + +import pytest +import torch +from tests import helper + +import tripy as tp + +PERF_CASES = [] + + +@dataclass +class PerfParam: + dtype: tp.dtype + """Data type to use""" + perf_threshold: float = 1.05 + """ + A multiplier indicating how much faster Tripy should be compared to Torch. + For example, 1.05 would mean that Tripy should be 5% faster than Torch. 
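+    (Concretely, the comparative test asserts that tripy_time * perf_threshold < torch_time.)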
+ """ + + +def perf_fixture(params: List[PerfParam]): + + def perf_fixture_impl(func): + PERF_CASES.append(pytest.lazy_fixture(func.__qualname__)) + + @pytest.fixture(params=params, scope="session", ids=lambda param: param.dtype) + def wrapped(request): + dtype, perf_threshold = request.param.dtype, request.param.perf_threshold + tripy_module, torch_module, input_infos = func(dtype, helper.TORCH_DTYPES[dtype]) + + torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_module.state_dict().items()} + torch_module.load_state_dict(torch_state_dict) + + tripy_compiled = tp.compile(tripy_module, kwargs=input_infos) + + inputs = {key: tp.iota(input_info.shape_bounds.opt, dtype=dtype) for key, input_info in input_infos.items()} + for tensor in inputs.values(): + tensor.eval() + + torch_compiled = torch.compile(torch_module) + + return tripy_compiled, torch_compiled, inputs, perf_threshold + + return wrapped + + return perf_fixture_impl diff --git a/tripy/tests/performance/test_perf.py b/tripy/tests/performance/test_perf.py new file mode 100644 index 000000000..0a4cf9257 --- /dev/null +++ b/tripy/tests/performance/test_perf.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import time
+from textwrap import dedent
+
+import pytest
+import torch
+
+# Need to import cases in order to populate PERF_CASES and load pytest fixtures
+from tests.performance.cases import *
+from tests.performance.conftest import PERF_CASES
+
+import tripy as tp
+
+
+@pytest.mark.parametrize("perf_case", PERF_CASES)
+def test_perf_regression(perf_case, benchmark):
+    compiled_tripy_module, _, inputs, _ = perf_case
+
+    def run_inference():
+        compiled_tripy_module(**inputs)
+        compiled_tripy_module.stream.synchronize()
+
+    benchmark(run_inference)
+
+
+@pytest.mark.parametrize("perf_case", PERF_CASES)
+def test_perf_comparative(perf_case):
+    compiled_tripy_module, compiled_torch_module, inputs, perf_threshold = perf_case
+
+    WARM_UP_RUNS = 10
+    ITERATIONS = 250
+
+    # Time Tripy
+    stream = tp.default_stream()
+
+    for _ in range(WARM_UP_RUNS):
+        compiled_tripy_module(**inputs)
+    stream.synchronize()
+
+    start = time.perf_counter()
+    for _ in range(ITERATIONS):
+        tripy_out = compiled_tripy_module(**inputs)
+    stream.synchronize()
+    end = time.perf_counter()
+
+    # Torch will report time in ms:
+    tripy_time = (end - start) * 1000
+
+    # Time Torch
+    torch_inputs = {key: torch.from_dlpack(value).to(device="cuda") for key, value in inputs.items()}
+
+    with torch.no_grad():
+        for _ in range(WARM_UP_RUNS):
+            compiled_torch_module(**torch_inputs)
+        torch.cuda.synchronize()
+
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+
+        start.record()
+        for _ in range(ITERATIONS):
+            torch_out = compiled_torch_module(**torch_inputs)
+        end.record()
+        torch.cuda.synchronize()
+
+        torch_time = start.elapsed_time(end)
+
+    # If the outputs don't match, then we're either not comparing apples-to-apples
+    # or there is an accuracy bug somewhere - either way we want to catch it here.
+    assert torch.allclose(torch_out, torch.from_dlpack(tripy_out))
+
+    print(f"Tripy was {torch_time / float(tripy_time)}x faster than Torch")
+    assert (tripy_time * perf_threshold) < torch_time
+
+
+def test_tripy_overhead():
+    def measure_overhead(num_io, warm_up_runs=10, iterations=1000):
+        """
+        Returns the overhead, in microseconds, introduced by Tripy wrapper code
+        for a function with the specified number of input/output tensors.
+        """
+        assert num_io > 0
+
+        arg_str = ", ".join(f"arg{num}" for num in range(num_io))
+        # Define `func` in the module globals so that it is visible below.
+        exec(
+            dedent(
+                f"""
+                def func({arg_str}):
+                    return {arg_str}
+                """
+            ),
+            globals(),
+        )
+
+        # By using an empty shape, we ensure that we are measuring nothing
+        # except Tripy Python overheads.
+        SHAPE = (0,)
+        compiled_func = tp.compile(func, args=[tp.InputInfo(shape=SHAPE, dtype=tp.float32) for _ in range(num_io)])
+
+        inputs = [tp.iota(shape=SHAPE, dtype=tp.float32) for _ in range(num_io)]
+        for input in inputs:
+            input.eval()
+
+        for _ in range(warm_up_runs):
+            compiled_func(*inputs)
+
+        start = time.perf_counter_ns()
+        for _ in range(iterations):
+            compiled_func(*inputs)
+        end = time.perf_counter_ns()
+
+        return (end - start) / (iterations * 1000.0)
+
+    assert measure_overhead(1) < 60.0
+
+    # Check that the overhead increases at most linearly as we increase the number of I/O tensors.
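+    # We measure for 3 through 9 I/O tensors and verify that consecutive differences
+    # (deltas) are bounded and roughly constant - constant deltas are what linear
+    # growth looks like.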
+ overheads = [measure_overhead(i) for i in range(3, 10)] + deltas = [n - p for p, n in zip(overheads[:-1], overheads[1:])] + print(f"overheads: {overheads}") + print(f"deltas: {deltas}") + assert all(delta < 45 for delta in deltas) + + # Ensure all deltas are within a few microseconds of each other + average_delta = sum(deltas) / float(len(deltas)) + assert all(abs(delta - average_delta) < 10 for delta in deltas) diff --git a/tripy/tests/test_function_registry.py b/tripy/tests/test_function_registry.py index c9e03f348..28681a5ca 100644 --- a/tripy/tests/test_function_registry.py +++ b/tripy/tests/test_function_registry.py @@ -189,7 +189,7 @@ def test_missing_arguments_gives_sane_error(self, registry): def func(a: int, b: int): return a + b - with helper.raises(TripyException, match="Some required arguments were not provided: \['a'\]"): + with helper.raises(TripyException, match=r"Some required arguments were not provided: \['a'\]"): registry["test"](b=0) def test_func_overload_caches_signature(self, registry): diff --git a/tripy/tripy/backend/api/executable.py b/tripy/tripy/backend/api/executable.py index c50f01712..16de171f7 100644 --- a/tripy/tripy/backend/api/executable.py +++ b/tripy/tripy/backend/api/executable.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tripy.backend.api.stream import default_stream -from tripy.backend.api.input_info import ArgInfo import base64 import inspect from typing import Sequence, Union @@ -21,11 +19,13 @@ import mlir_tensorrt.runtime.api as runtime from tripy import export +from tripy.backend.api.input_info import ArgInfo from tripy.backend.mlir import Executor from tripy.backend.mlir import utils as mlir_utils from tripy.common.exception import raise_error from tripy.frontend import Tensor from tripy.utils import json as json_utils +from tripy.utils.stack_info import StackInfo @export.public_api(document_under="compiling_code") @@ -42,9 +42,9 @@ def __init__(self, executable, arg_names, output_devices): self._executable = executable self._executor = Executor(self._executable) self._arg_names = arg_names + self._num_expected_args = len(arg_names) self._output_devices = output_devices self._executable_signature = self._executable.get_signature("main") - self._stream = default_stream() # Build a signature so the executable works with `inspect.signature` params = [] @@ -57,12 +57,11 @@ def __init__(self, executable, arg_names, output_devices): @property def stream(self): - return self._stream + return self._executor.stream @stream.setter def stream(self, stream): - self._stream = stream - self._executor.stream = self._stream + self._executor.stream = stream def __call__(self, *args, **kwargs) -> Union[Tensor, Sequence[Tensor]]: """ @@ -91,13 +90,12 @@ def add(a, b): out = compiled_add(a, b) """ - NUM_ARGS = len(args) + len(kwargs) + num_positional = len(args) + NUM_ARGS = num_positional + len(kwargs) - input_tensors = [] - - input_tensors.extend(args) + input_tensors = list(args) # Need to get arguments in the order of self._arg_names, which may be different from kwargs ordering. 
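+        # For example, if `self._arg_names` is `["a", "b", "c"]` and one positional
+        # argument was provided, then `expected_kwargs` is `["b", "c"]`.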
- expected_kwargs = self._arg_names[len(args) :] + expected_kwargs = self._arg_names[num_positional:] for name in expected_kwargs: if name not in kwargs: raise_error(f"Missing argument: {name}", [f"Expected the following arguments: {self._arg_names}"]) @@ -110,16 +108,17 @@ def add(a, b): f"Extra keyword arguments: {list(kwargs.keys())}", [ f"Expected the following arguments: {self._arg_names}.\n" - f"Note: The following arguments were already provided as positional arguments: {self._arg_names[:len(args)]}" + f"Note: The following arguments were already provided as positional arguments: {self._arg_names[:num_positional]}" ], ) # We do this after kwarg checks since those will be more informative (we can explain which arguments are missing/extra). - if NUM_ARGS != len(self._arg_names): + + if NUM_ARGS != self._num_expected_args: raise_error( "Incorrect number of arguments.", [ - f"Expected {len(self._arg_names)} arguments but got {NUM_ARGS}.\n" + f"Expected {self._num_expected_args} arguments but got {NUM_ARGS}.\n" f"Note: Expected arguments were: {self._arg_names}", ], ) diff --git a/tripy/tripy/backend/api/input_info.py b/tripy/tripy/backend/api/input_info.py index 457de0fb3..157fa7efb 100644 --- a/tripy/tripy/backend/api/input_info.py +++ b/tripy/tripy/backend/api/input_info.py @@ -28,9 +28,7 @@ class InputInfo: Captures information about an input to a compiled function. """ - def __init__( - self, shape: Sequence[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]], dtype: "tripy.dtype" - ) -> None: + def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tripy.dtype") -> None: """ Args: shape: The shape of the input. diff --git a/tripy/tripy/backend/api/stream.py b/tripy/tripy/backend/api/stream.py index 79160f2c6..92099cf9d 100644 --- a/tripy/tripy/backend/api/stream.py +++ b/tripy/tripy/backend/api/stream.py @@ -105,10 +105,6 @@ def synchronize(self) -> None: def __eq__(self, other): if not isinstance(other, Stream): return False - - if not (hasattr(self, "_active_cuda_stream") and hasattr(other, "_active_cuda_stream")): - return False - return self._active_cuda_stream == other._active_cuda_stream def __str__(self): diff --git a/tripy/tripy/backend/mlir/executor.py b/tripy/tripy/backend/mlir/executor.py index 95cc98233..71bdf3932 100644 --- a/tripy/tripy/backend/mlir/executor.py +++ b/tripy/tripy/backend/mlir/executor.py @@ -17,31 +17,34 @@ from typing import List -import mlir_tensorrt.compiler.api as compiler import mlir_tensorrt.runtime.api as runtime +from tripy.backend.api.stream import default_stream from tripy.backend.mlir.memref import create_empty_memref +from tripy.backend.mlir.utils import MLIRRuntimeClient, convert_runtime_dtype_to_tripy_dtype from tripy.backend.utils import TensorInfo from tripy.common import datatype, device from tripy.common.exception import raise_error -from tripy.utils import log_time, make_tuple +from tripy.common.utils import convert_list_to_array +from tripy.utils import make_tuple class Executor: def __init__(self, executable: runtime.Executable) -> None: - from tripy.backend.mlir.utils import MLIRRuntimeClient - from tripy.backend.api.stream import default_stream - self.runtime_client = MLIRRuntimeClient() session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) self.session = runtime.RuntimeSession(session_options, executable) self.device = self.runtime_client.get_devices()[0] # Assume a single device is available. 
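+        # Cache argument counts and output memref types from the signature up front
+        # so they do not need to be recomputed on every execution.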
self.signature = executable.get_signature("main") self.stream = default_stream() + self.num_input_args = self.signature.get_num_input_args() + self.num_output_args = self.signature.get_num_output_args() + self.output_args = [ + self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args) + ] + self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args] def _create_shape_memref(self, shape): - from tripy.common.utils import convert_list_to_array - shape = make_tuple(shape) if len(shape) == 0: # create an empty memref @@ -55,34 +58,21 @@ def _create_shape_memref(self, shape): stream=self.stream._active_cuda_stream, ) - def _get_inputs_runtime_shape(self, inputs): - inputs_shape = [] - for input in inputs: - inputs_shape.append(input.trace_tensor.producer.data.shape) - return inputs_shape - def _get_outputs_shape(self): - offset = self.signature.get_num_input_args() outputs_shape = [] all_outputs_known = True - for output_index in range(self.signature.get_num_output_args()): - arg_index = output_index + offset - arg = self.signature.get_arg(arg_index) - assert compiler.MemRefType.isinstance(arg) - memref = runtime.MemRefType(arg) - rank = len(memref.shape) - + for memref in self.output_memrefs: outputs_shape.append(memref.shape) - if rank > 0: - all_outputs_known &= all(dim >= 0 for dim in memref.shape) + all_outputs_known &= all(dim >= 0 for dim in memref.shape) return outputs_shape, all_outputs_known - def _execute_shape_inference(self, inputs_shape, outputs_shape): - # Only execute shape inference if shape function name is valid. - assert ( - self.signature.get_shape_func_name() - ), f"Shape inference function is missing while output shapes are not known." + def _get_inputs_runtime_shape(self, inputs): + inputs_shape = [] + for input in inputs: + inputs_shape.append(input.trace_tensor.producer.data.shape) + return inputs_shape + def _execute_shape_inference(self, inputs_shape, outputs_shape): inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape] outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape] self.session.execute_function( @@ -93,41 +83,24 @@ def _execute_shape_inference(self, inputs_shape, outputs_shape): return outputs_runtime_shape def _get_output_tensor_info(self, outputs_runtime_shape, output_devices): - from tripy.backend.mlir.utils import convert_runtime_dtype_to_tripy_dtype - - offset = self.signature.get_num_input_args() outputs_tensor_info = [] - for output_index in range(self.signature.get_num_output_args()): - arg_index = output_index + offset - arg = self.signature.get_arg(arg_index) - assert compiler.MemRefType.isinstance(arg) or compiler.ScalarType.isinstance( - arg - ), "Argument must be either MemRefType or ScalarType" - assert compiler.MemRefType.isinstance( - arg - ), "ScalarType argument are not yet supported" # 158: Add scalar type output argument support. 
- memref = compiler.MemRefType(arg) + for index in range(self.num_output_args): + memref = self.output_memrefs[index] dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype) - device_type = "gpu" if memref.address_space == runtime.PointerType.device else "cpu" - if output_devices[output_index]: - device_type = output_devices[output_index].kind - is_static_shape = all(dim >= 0 for dim in memref.shape) - if is_static_shape: - outputs_tensor_info.append( - TensorInfo(len(memref.shape), tuple(memref.shape), dtype, device(device_type)) - ) - else: - runtime_shape = [ - rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[output_index]) - ] - outputs_tensor_info.append( - TensorInfo( - len(runtime_shape), - tuple(runtime_shape), - dtype, - device(device_type), - ) + + output_device = output_devices[index] + if not output_device: + output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0)) + + runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])] + outputs_tensor_info.append( + TensorInfo( + len(runtime_shape), + tuple(runtime_shape), + dtype, + output_device, ) + ) return outputs_tensor_info def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]): @@ -138,21 +111,9 @@ def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]): output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices) return output_tensor_info - @property - def stream(self): - return self._stream - - @stream.setter - def stream(self, stream): - self._stream = stream - - @log_time - def execute(self, output_devices=List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: - from tripy.frontend.trace.ops import Storage - + def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: in_args = [] for inp in inputs: - assert isinstance(inp.trace_tensor.producer, Storage) and inp.trace_tensor.producer.has_memref memref = inp.trace_tensor.producer.data # HACK (#155): MLIR-TensorRT requires inputs to be on device. # Remove explicit copy to device once #155 is addressed. 
diff --git a/tripy/tripy/backend/mlir/memref.py b/tripy/tripy/backend/mlir/memref.py index 5b6da67a6..73f5f9f55 100644 --- a/tripy/tripy/backend/mlir/memref.py +++ b/tripy/tripy/backend/mlir/memref.py @@ -41,7 +41,7 @@ def _cached_create_empty_memref(shape: Sequence[int], dtype: str, device_kind: s def create_empty_memref( shape: Sequence[int], dtype: str, - device: tp_device = tp_device("gpu"), + device: tp_device = tp_device(("gpu", 0)), stream=None, use_cache: bool = True, ): diff --git a/tripy/tripy/backend/mlir/utils.py b/tripy/tripy/backend/mlir/utils.py index facf5a363..552987169 100644 --- a/tripy/tripy/backend/mlir/utils.py +++ b/tripy/tripy/backend/mlir/utils.py @@ -288,15 +288,17 @@ def redirect_stderr() -> BinaryIO: def convert_tripy_dtype_to_runtime_dtype(dtype: datatype.dtype) -> runtime.ScalarTypeCode: - if dtype not in TRIPY_DTYPE_TO_MLIR_TRT: + try: + return TRIPY_DTYPE_TO_MLIR_TRT[dtype] + except KeyError: raise_error(f"Data type: '{dtype}' does not have a corresponding runtime data type") - return TRIPY_DTYPE_TO_MLIR_TRT.get(dtype) def convert_runtime_dtype_to_tripy_dtype(dtype: runtime.ScalarTypeCode) -> datatype.dtype: - if dtype not in MLIR_TRT_TO_TRIPY_DTYPE: + try: + return MLIR_TRT_TO_TRIPY_DTYPE[dtype] + except KeyError: raise_error(f"Data type: '{dtype}' does not have a corresponding tripy data type") - return MLIR_TRT_TO_TRIPY_DTYPE.get(dtype) def is_any_dim_dynamic(mlir_tensor): diff --git a/tripy/tripy/common/device.py b/tripy/tripy/common/device.py index da0882c4a..3f8298c01 100644 --- a/tripy/tripy/common/device.py +++ b/tripy/tripy/common/device.py @@ -22,6 +22,8 @@ from tripy.common.exception import TripyException from tripy.utils.json import Decoder, Encoder +_VALID_KINDS = {"cpu", "gpu"} + @export.public_api() @dataclass @@ -60,27 +62,30 @@ def __init__(self, device: str) -> None: assert gpu_1.kind == "gpu" assert gpu_1.index == 1 """ - - kind, _, index = device.partition(":") - kind = kind.lower() - - if index: - try: - index = int(index) - except ValueError: - raise TripyException(f"Could not interpret: {index} as an integer") - else: - index = 0 - - if index < 0: - raise TripyException(f"Device index must be a non-negative integer, but was: {index}") - - VALID_KINDS = {"cpu", "gpu"} - if kind not in VALID_KINDS: - raise TripyException(f"Unrecognized device kind: {kind}. Choose from: {list(VALID_KINDS)}") - - self.kind = kind - self.index = index + try: + # Fast constructor for the critical path. If a Tuple[str, int] is provided, then + # we bypass all the logic to parse the information from a string. + self.kind, self.index = device + except ValueError: + kind, _, index = device.partition(":") + kind = kind.lower() + + if index: + try: + index = int(index) + except ValueError: + raise TripyException(f"Could not interpret: {index} as an integer") + else: + index = 0 + + if index < 0: + raise TripyException(f"Device index must be a non-negative integer, but was: {index}") + + if kind not in _VALID_KINDS: + raise TripyException(f"Unrecognized device kind: {kind}. 
Choose from: {list(_VALID_KINDS)}") + + self.kind = kind + self.index = index def __str__(self) -> str: return f"{self.kind}:{self.index}" diff --git a/tripy/tripy/flat_ir/flat_ir.py b/tripy/tripy/flat_ir/flat_ir.py index 60ee941d4..d84aa2508 100644 --- a/tripy/tripy/flat_ir/flat_ir.py +++ b/tripy/tripy/flat_ir/flat_ir.py @@ -18,10 +18,11 @@ from typing import Dict, List, Sequence, Set, Union from mlir_tensorrt.compiler.dialects._ods_common import get_op_result_or_value +from mlir_tensorrt.runtime.api import MemRefValue + from tripy import utils from tripy.common.shape_bounds import ShapeBounds - -from tripy.flat_ir.ops import ConstantOp +from tripy.utils.utils import list_to_tuple class FlatIR: @@ -96,6 +97,7 @@ def to_mlir(self): """ from mlir_tensorrt.compiler import ir from mlir_tensorrt.compiler.dialects import func as func_dialect + from tripy.backend.mlir.utils import make_ir_context, make_tensor_location from tripy.flat_ir.function import FlatIRFunction from tripy.flat_ir.ops.base import BaseFlatIROp @@ -206,7 +208,7 @@ def _process_function_body( def _get_op_inputs(op: ir.Operation, mlir_tensor_map: Dict[str, ir.Value]) -> List[ir.Value]: """Get the inputs for an operation, casting to dynamic tensors if necessary.""" - from tripy.backend.mlir.utils import is_any_dim_dynamic, cast_to_dynamic_ranked_tensor + from tripy.backend.mlir.utils import cast_to_dynamic_ranked_tensor, is_any_dim_dynamic from tripy.flat_ir.ops import DynamicBroadcastOp, DynamicReshapeOp op_inputs = [] @@ -352,9 +354,6 @@ def register_tensor(self, tensor: "FlatIRTensor") -> "FlatIRTensor": return tensor def _get_constant_key(self, op): - from mlir_tensorrt.runtime.api import MemRefValue - from tripy.utils.utils import list_to_tuple - if isinstance(op.data, MemRefValue): # use data pointer as key when data is a memref, # usually come from users, no need to deduplicate @@ -372,9 +371,9 @@ def integrate_subgraph(self, inputs: List["FlatIRTensor"], outputs: List["FlatIR Integrate a subgraph delineated by the given inputs and outputs into this FlatIR. 
""" from tripy.flat_ir.function import FlatIRFunction + from tripy.flat_ir.ops import ConstantOp from tripy.flat_ir.ops.base import BaseFlatIROp from tripy.flat_ir.tensor import FlatIRTensor - from tripy.flat_ir.ops import ConstantOp seen_tensors: Set[int] = set() dedup_func_op_map: Dict[int, List[FlatIRFunction]] = {} diff --git a/tripy/tripy/frontend/module/conv.py b/tripy/tripy/frontend/module/conv.py index c194fc256..5a1a53735 100644 --- a/tripy/tripy/frontend/module/conv.py +++ b/tripy/tripy/frontend/module/conv.py @@ -77,6 +77,7 @@ def __init__( self.stride = utils.default(stride, (1,) * (rank - 2)) self.dilation = utils.default(dilation, (1,) * (rank - 2)) + self.bias = None if bias: self.bias = DefaultParameter((out_channels,), dtype=dtype) @@ -268,7 +269,7 @@ def __call__(self, input: "tripy.Tensor") -> "tripy.Tensor": None, # lhs_dilation for transposed conv only self.dilation, ) - if hasattr(self, "bias"): + if self.bias is not None: bias_shape_to_broadcast = (1, self.weight.shape[0]) + (1,) * (self.weight.rank - 2) x += reshape(self.bias, bias_shape_to_broadcast) return x diff --git a/tripy/tripy/frontend/module/conv_transpose.py b/tripy/tripy/frontend/module/conv_transpose.py index ec9da8687..0357898fb 100644 --- a/tripy/tripy/frontend/module/conv_transpose.py +++ b/tripy/tripy/frontend/module/conv_transpose.py @@ -238,7 +238,7 @@ def __call__(self, input: "tripy.Tensor") -> "tripy.Tensor": self.stride, # effectively lhs_dilation for StableHLO self.dilation, ) - if hasattr(self, "bias"): + if self.bias is not None: bias_shape_to_broadcast = (1, weight.shape[0]) + (1,) * (rank - 2) x += reshape(self.bias, bias_shape_to_broadcast) return x diff --git a/tripy/tripy/frontend/module/linear.py b/tripy/tripy/frontend/module/linear.py index 3af144a7e..0802dd9c8 100644 --- a/tripy/tripy/frontend/module/linear.py +++ b/tripy/tripy/frontend/module/linear.py @@ -92,14 +92,14 @@ def __init__( # Replace with random weights when #74 is completed. 
self.weight = DefaultParameter((out_features, in_features), dtype=dtype) + self.bias = None if bias: self.bias = DefaultParameter((out_features,), dtype=dtype) self.quant_dtype = quant_dtype self.weight_quant_dim = weight_quant_dim - if quant_dtype is not None: - self.weight_scale = None - self.input_scale = None + self.weight_scale = None + self.input_scale = None def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor": r""" @@ -136,7 +136,7 @@ def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor": weight = self.weight out = x @ (transpose(weight, 1, 0)) - if hasattr(self, "bias"): + if self.bias is not None: out = out + unsqueeze(self.bias, 0) return out diff --git a/tripy/tripy/frontend/module/module.py b/tripy/tripy/frontend/module/module.py index e920354dc..c69b771bc 100644 --- a/tripy/tripy/frontend/module/module.py +++ b/tripy/tripy/frontend/module/module.py @@ -19,8 +19,8 @@ import operator from typing import Any, Dict, Iterator, List, Tuple, Union, Sequence, TypeVar -from tripy import export -from tripy.common.exception import raise_error +from tripy import export, utils +from tripy.common.exception import raise_error, _make_stack_info_message from tripy.frontend.module.parameter import Parameter from tripy.logging import logger @@ -41,11 +41,11 @@ def _check_param_compatible(original_param, new_param, param_name): def _is_homogeneous_container(container: Sequence, typ: T): - return all(isinstance(op, typ) for op in container) + return all(isinstance(elem, typ) for elem in container) def _contains_types(container: Sequence, types: type): - return any(any(isinstance(value, typ) for typ in types) for value in container) + return any(any(isinstance(elem, typ) for typ in types) for elem in container) @export.public_api(document_under="modules/index.rst") @@ -106,8 +106,17 @@ def __setattr__(self, name: str, value: Any) -> None: if isinstance(value, List) or isinstance(value, Dict): container = value if isinstance(value, List) else value.values() - if _contains_types(container, [Parameter, Module]) and not _is_homogeneous_container(container, Parameter): - logger.warning("A container of mixed types will not be registered with this module's state_dict().") + if _contains_types(container, [Parameter, Module]) and ( + not _is_homogeneous_container(container, Parameter) and not _is_homogeneous_container(container, Module) + ): + stack_info = utils.get_stack_info() + stack_info.fetch_source_code() + stack_info_msg = _make_stack_info_message(stack_info) + + logger.warning( + "A container of mixed types will not be registered with this module's state_dict()." 
+ + (f"\nNote: container was set here: {stack_info_msg}" if stack_info_msg else "") + ) def state_dict(self) -> Dict[str, Parameter]: r""" @@ -276,14 +285,10 @@ def _iterate_members_of_type(self, typ: T) -> Iterator[Tuple[str, T]]: for name, value in vars(self).items(): if isinstance(value, typ): yield name, value - elif isinstance(value, List) and _contains_types(value, [typ]) and _is_homogeneous_container(value, typ): + elif isinstance(value, List) and _is_homogeneous_container(value, typ): for i, obj in enumerate(value): yield f"{name}.{i}", obj - elif ( - isinstance(value, Dict) - and _contains_types(value.values(), [typ]) - and _is_homogeneous_container(value.values(), typ) - ): + elif isinstance(value, Dict) and _is_homogeneous_container(value.values(), typ): for key, obj in value.items(): yield f"{name}.{key}", obj diff --git a/tripy/tripy/frontend/module/parameter.py b/tripy/tripy/frontend/module/parameter.py index 7568a7124..d883bef91 100644 --- a/tripy/tripy/frontend/module/parameter.py +++ b/tripy/tripy/frontend/module/parameter.py @@ -15,9 +15,10 @@ # limitations under the License. # +import math from typing import Any, Sequence -from tripy import export, utils +from tripy import export from tripy.frontend.tensor import Tensor from tripy.utils import Result @@ -79,7 +80,7 @@ def __init__(self, shape: Sequence[int], dtype: "tripy.dtype") -> None: from tripy.frontend.ops.tensor_initializers import arange from tripy.frontend.trace.ops.reshape import reshape - super().__init__(reshape(arange(utils.volume(shape), dtype), shape)) + super().__init__(reshape(arange(math.prod(shape), dtype), shape)) self._shape = shape self._dtype = dtype diff --git a/tripy/tripy/frontend/shape.py b/tripy/tripy/frontend/shape.py index 56b94bf80..9b38d09bc 100644 --- a/tripy/tripy/frontend/shape.py +++ b/tripy/tripy/frontend/shape.py @@ -17,11 +17,11 @@ from typing import Any, Optional, Sequence, Union +import tripy.frontend.utils as frontend_utils from tripy import export, utils from tripy.common.datatype import int32 from tripy.common.exception import raise_error from tripy.frontend.tensor import Tensor -import tripy.frontend.utils as frontend_utils @export.public_api(document_under="shape/index.rst") @@ -203,13 +203,12 @@ def __radd__(self, other): def __mul__(self, other): from tripy.frontend.trace.ops.binary_elementwise import maximum from tripy.frontend.trace.ops.expand import expand - from tripy.frontend.trace.ops.reshape import reshape, flatten + from tripy.frontend.trace.ops.reshape import flatten, reshape from tripy.frontend.trace.ops.unsqueeze import unsqueeze # We unsqueeze self into shape [1, len(self)], so giving [other, len(self)] as # the argument to expand will result in a shape of [other, len(self)] by # copying self the correct number of times. 
- # Only defined with a scalar argument if not isinstance(other, Tensor): # note: Python does not accept floats as arguments for list multiplication either @@ -334,6 +333,9 @@ def __init__( ) super().__init__(data=data, dtype=int32, name=name, device=device) + def __int__(self) -> int: + return self.tolist() + def __repr__(self) -> str: # denote the representation as a shape rather than a tensor tensor_repr = super().__repr__() diff --git a/tripy/tripy/frontend/tensor.py b/tripy/tripy/frontend/tensor.py index e46d886ed..b22bb7e1f 100644 --- a/tripy/tripy/frontend/tensor.py +++ b/tripy/tripy/frontend/tensor.py @@ -29,6 +29,7 @@ from tripy.common.exception import raise_error from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY from tripy.frontend.trace.ops import Storage +from tripy.frontend.trace.tensor import TraceTensor from tripy.utils.stack_info import StackInfo @@ -96,7 +97,6 @@ def __init__( tensor = tp.Tensor([1.0, 2.0, 3.0], dtype=tp.float32) """ - from tripy.frontend.trace.tensor import TraceTensor stack_info = StackInfo([]) if fetch_stack_info: @@ -122,17 +122,6 @@ def __init__( else: Storage.build_internal([], [self.trace_tensor], data, dtype, device) - # Storage should populate attrs of trace_tensor - assert all( - attr is not None - for attr in [ - self.trace_tensor.shape, - self.trace_tensor.dtype, - self.trace_tensor.device, - self.trace_tensor.producer, - ] - ) - # Explicit cast if necessary # TODO(#155): Add copy as well when host allocation is fixed # Also make device as a property, similar to dtype @@ -174,13 +163,15 @@ def rank(self): return self.trace_tensor.rank def eval(self) -> runtime.MemRefValue: + if isinstance(self.trace_tensor.producer, Storage) and self.trace_tensor.producer.has_memref: + # Exit early if the tensor has already been evaluated. + # This happens before the imports below so we don't incur extra overhead. + return self.trace_tensor.producer.data + from tripy.backend.mlir.compiler import Compiler from tripy.backend.mlir.executor import Executor from tripy.frontend.trace import Trace - if isinstance(self.trace_tensor.producer, Storage) and self.trace_tensor.producer.has_memref: - return self.trace_tensor.producer.data - trace = Trace([self]) flat_ir = trace.to_flat_ir() mlir = flat_ir.to_mlir() @@ -191,7 +182,7 @@ def eval(self) -> runtime.MemRefValue: # Upon computing the value of this tensor, we switch it to have a `Storage` # parameter so that it does not need to be computed again. data = executor.execute([out.device for out in flat_ir.outputs]) - executor._stream.synchronize() + executor.stream.synchronize() assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor" data = data[0] # Data is present now. Assign the underlying device type. diff --git a/tripy/tripy/frontend/trace/ops/base.py b/tripy/tripy/frontend/trace/ops/base.py index cdaae3982..36d5e3130 100644 --- a/tripy/tripy/frontend/trace/ops/base.py +++ b/tripy/tripy/frontend/trace/ops/base.py @@ -17,9 +17,10 @@ import abc from dataclasses import dataclass -from typing import List, Set, Union, Optional +from typing import List, Optional, Set, Union from tripy import utils +from tripy.common.exception import raise_error from tripy.utils import Result @@ -47,10 +48,6 @@ def build_internal( *args and **kwargs are passed along to the trace operation's constructor. 
""" - from tripy.frontend.trace.tensor import TraceTensor - - assert all(isinstance(tensor, TraceTensor) for tensor in inputs + outputs) - op = cls(inputs, outputs, *args, **kwargs) for out in op.outputs: out.producer = op @@ -73,7 +70,6 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[ of returning a list of output tensors. """ - from tripy.common.exception import raise_error from tripy.frontend.shape import Shape, ShapeScalar from tripy.frontend.tensor import Tensor diff --git a/tripy/tripy/frontend/trace/ops/fill.py b/tripy/tripy/frontend/trace/ops/fill.py index fa03efbc9..2ecf259d9 100644 --- a/tripy/tripy/frontend/trace/ops/fill.py +++ b/tripy/tripy/frontend/trace/ops/fill.py @@ -20,12 +20,11 @@ from typing import Optional, Sequence, Union import tripy.frontend.trace.ops.utils as op_utils -from tripy import export, utils, constraints +import tripy.frontend.utils as frontend_utils +from tripy import constraints, export, utils from tripy.common import datatype -from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp -from tripy.common.datatype import DATA_TYPES @dataclass(repr=False) @@ -42,7 +41,7 @@ def infer_dtypes(self): def infer_devices(self): from tripy.common import device - self.outputs[0].device = device("gpu") + self.outputs[0].device = device(("gpu", 0)) def infer_rank(self): if self.output_rank is None: diff --git a/tripy/tripy/frontend/trace/ops/gather.py b/tripy/tripy/frontend/trace/ops/gather.py index 731377748..7e19db8d0 100644 --- a/tripy/tripy/frontend/trace/ops/gather.py +++ b/tripy/tripy/frontend/trace/ops/gather.py @@ -55,7 +55,7 @@ def infer_dtypes(self): def infer_devices(self): from tripy.common import device - self.outputs[0].device = device("gpu") + self.outputs[0].device = self.inputs[0].device @frontend_utils.make_function def to_flat_ir(self, inputs, outputs): diff --git a/tripy/tripy/frontend/trace/ops/iota.py b/tripy/tripy/frontend/trace/ops/iota.py index ee3716add..d98306c2f 100644 --- a/tripy/tripy/frontend/trace/ops/iota.py +++ b/tripy/tripy/frontend/trace/ops/iota.py @@ -60,7 +60,7 @@ def infer_dtypes(self): def infer_devices(self): from tripy.common import device - self.outputs[0].device = device("gpu") + self.outputs[0].device = device(("gpu", 0)) def to_flat_ir(self, inputs, outputs): from tripy.flat_ir.ops import DynamicIotaOp diff --git a/tripy/tripy/frontend/trace/ops/storage.py b/tripy/tripy/frontend/trace/ops/storage.py index 2c781ba22..7dc334612 100644 --- a/tripy/tripy/frontend/trace/ops/storage.py +++ b/tripy/tripy/frontend/trace/ops/storage.py @@ -51,22 +51,21 @@ def __init__( self.data = data if isinstance(data, runtime.MemRefValue): - assert not any([dtype, device]), "Internal usage: dtype/device are inherited from memref." 
self.dtype = mlir_utils.convert_runtime_dtype_to_tripy_dtype(self.data.dtype) self.shape = tuple(data.shape) - self.device = tp_device("gpu") if data.address_space == runtime.PointerType.device else tp_device("cpu") + self.device = tp_device(("gpu" if data.address_space == runtime.PointerType.device else "cpu", 0)) self.has_memref = True elif common_utils.is_empty(data): # special case: empty tensor self.dtype = utils.default(dtype, datatype.float32) self.shape = tuple(utils.get_shape(data)) self.data = memref.create_empty_memref(shape=self.shape, dtype=self.dtype) - self.device = utils.default(device, tp_device("gpu")) + self.device = utils.default(device, tp_device(("gpu", 0))) self.has_memref = True else: self.dtype = dtype if dtype else common_utils.get_element_type(data) self.shape = tuple(utils.get_shape(data)) - self.device = utils.default(device, tp_device("gpu")) + self.device = utils.default(device, tp_device(("gpu", 0))) self.has_memref = False # for storage, we will always consider the result to be an ordinary tensor @@ -90,7 +89,7 @@ def infer_rank(self): def infer_devices(self): # TODO(#155): Fix allocation on host - self.outputs[0].device = tp_device("gpu") + self.outputs[0].device = tp_device(("gpu", 0)) def to_flat_ir(self, inputs, outputs): from tripy.flat_ir.ops import ConstantOp diff --git a/tripy/tripy/frontend/trace/trace.py b/tripy/tripy/frontend/trace/trace.py index 9cbc0fdae..a49b74798 100644 --- a/tripy/tripy/frontend/trace/trace.py +++ b/tripy/tripy/frontend/trace/trace.py @@ -20,7 +20,6 @@ from tripy.common.exception import raise_error from tripy.common.shape_bounds import ShapeBounds -from tripy.frontend.trace.ops import BaseTraceOp from tripy.frontend.trace.tensor import TraceTensor from tripy.frontend.utils import topological_sort from tripy.logging import logger @@ -31,18 +30,6 @@ class Trace: A flattened representation of a computation graph expressed by one or more Tensors. """ - def _infer_tensor_info(self): - """ - Infers basic information, like device, for all tensors in the trace. - """ - - # Compute and cache device information for all tensors - for inp in self.inputs: - inp.producer.infer_devices() - - for op in self.ops: - op.infer_devices() - def __init__( self, tensors: Sequence["tripy.Tensor"], @@ -56,7 +43,7 @@ def __init__( shapes: The shape profile, consisting of min, opt, and max shapes for each input tensors. Must be in the same order as `inputs`. """ - self.ops: List[BaseTraceOp] = [] + self.ops: List["BaseTraceOp"] = [] self.inputs: List[TraceTensor] = [inp.trace_tensor for inp in inputs] self.outputs: List[TraceTensor] = [tensor.trace_tensor for tensor in tensors] self.shapes = shapes @@ -98,9 +85,6 @@ def check_name(tensor): # Reverse the order of the layers so they are topologically sorted self.ops = topological_sort(self.ops) - # Perform shape/dtype/device inference to fill shape information for all tensors. - self._infer_tensor_info() - logger.trace(lambda: f"{self}\n") def __str__(self) -> str: diff --git a/tripy/tripy/frontend/utils.py b/tripy/tripy/frontend/utils.py index 3155fd81c..dc077cfa7 100644 --- a/tripy/tripy/frontend/utils.py +++ b/tripy/tripy/frontend/utils.py @@ -23,9 +23,8 @@ from tripy import utils from tripy.common.exception import raise_error -from tripy.flat_ir.ops import BaseFlatIROp from tripy.flat_ir.function import FlatIRFunction -from tripy.frontend.trace.ops import BaseTraceOp +from tripy.flat_ir.ops import BaseFlatIROp # Try to include correct column offsets for non-tensor arguments. 
@@ -40,8 +39,8 @@ def _add_column_info_for_non_tensor( list_index=None, TensorType=None, ): - from tripy.frontend.tensor import Tensor from tripy.frontend.shape import Shape + from tripy.frontend.tensor import Tensor from tripy.frontend.trace.ops.cast import cast TensorType = utils.default(TensorType, Tensor) @@ -241,8 +240,8 @@ def impl(func): @functools.wraps(func) def wrapper(*args, **kwargs): - from tripy.frontend.tensor import Tensor from tripy.frontend.shape import Shape, ShapeScalar + from tripy.frontend.tensor import Tensor all_args = utils.merge_function_arguments(func, *args, **kwargs) @@ -464,6 +463,7 @@ def make_function(func): @functools.wraps(func) def wrapped(*args, **kwargs): from tripy.flat_ir.tensor import FlatIRTensor + from tripy.frontend.trace.ops.base import BaseTraceOp # Determine if this is a method or a free function. is_method = inspect.ismethod(func) or (inspect.isfunction(func) and args and isinstance(args[0], BaseTraceOp)) @@ -565,7 +565,7 @@ def get_or_create_cloned_tensor(tensor): return wrapped -def topological_sort(ops: List[Union[BaseTraceOp, BaseFlatIROp]]) -> List[Union[BaseTraceOp, BaseFlatIROp]]: +def topological_sort(ops: List[Union["BaseTraceOp", BaseFlatIROp]]) -> List[Union["BaseTraceOp", BaseFlatIROp]]: """ This utility to topologically sort a graph that can be a Trace or a FlatIR graph. """ diff --git a/tripy/tripy/utils/utils.py b/tripy/tripy/utils/utils.py index 72dae34fd..1fd406a93 100644 --- a/tripy/tripy/utils/utils.py +++ b/tripy/tripy/utils/utils.py @@ -21,6 +21,7 @@ import hashlib import inspect import os +import math import time import typing from typing import Any, List, Sequence, Union @@ -69,9 +70,10 @@ def log_time(func): @functools.wraps(func) def wrapper(*args, **kwargs): - start_time = time.time() + start = time.perf_counter() result = func(*args, **kwargs) - logger.timing(f"{func.__name__} executed in {time.time() - start_time:.4f} seconds") + end = time.perf_counter() + logger.timing(f"{func.__name__} executed in {end - start:.4f} seconds") return result return wrapper @@ -159,23 +161,6 @@ def list_to_tuple(nested_list): ## -def volume(shape): - """ - Computes volume of a tensor shape. - - Args: - shape: The shape of a tensor - - Returns: - Volume of tensor (float) - """ - - volume = 1 - for s in shape: - volume *= s - return volume - - def flatten_list(data): """ Flattens a nested list into a single list. @@ -215,7 +200,7 @@ def get_shape(data): def should_omit_constant_in_str(shape): - return volume(shape) >= constants.CONSTANT_IR_PRINT_VOLUME_THRESHOLD + return math.prod(shape) >= constants.CONSTANT_IR_PRINT_VOLUME_THRESHOLD def get_dataclass_fields(obj: Any, BaseClass: type) -> List[dataclasses.Field]: