Skip to content

Commit 26fd6f0

Browse files
Updates Tripy to work with Python 3.9, moves files into utils
- Updates the `Dockerfile` to use a Python 3.9 container (the lowest version we support) so we can be sure we're not relying on features that are not available there. This unforunately means that we need to drop `lldb` from the container since it depends on Python 3.10. However, `gdb` is still available. - Moves various files into `utils`. We were previously unable to do this because the `__init__.py` file of `utils` would import lots of things, creating circular dependencies. Now, the top-level `utils` module will not import everything, meaning there is no risk of circular dependencies. - Makes some minor changes in the code to make Python 3.9 work (e.g. using `._name` instead of `.__name__`)
1 parent 9b90e4f commit 26fd6f0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+340
-312
lines changed

tripy/Dockerfile

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
FROM ubuntu:22.04
1+
FROM python:3.9
22

33
LABEL org.opencontainers.image.description="Tripy development container"
44

55
WORKDIR /tripy
66

7-
SHELL ["/bin/bash", "-c"]
7+
ENTRYPOINT ["/bin/bash"]
88

99
# Setup user account
1010
ARG uid=1000
@@ -30,21 +30,5 @@ RUN pip install build .[docs,dev,test,build] \
3030
--extra-index-url https://download.pytorch.org/whl \
3131
--extra-index-url https://pypi.nvidia.com
3232

33-
# Installl lldb for debugging purposes in Tripy container.
34-
# The LLVM version should correspond on LLVM_VERSION specified in https://github.com/NVIDIA/TensorRT-Incubator/blob/main/mlir-tensorrt/build_tools/docker/Dockerfile#L30.
35-
ARG LLVM_VERSION=17
36-
ENV LLVM_VERSION=$LLVM_VERSION
37-
ENV LLVM_PACKAGES="lldb-${LLVM_VERSION}"
38-
RUN echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-$LLVM_VERSION main" > /etc/apt/sources.list.d/llvm.list && \
39-
echo "deb-src http://apt.llvm.org/jammy/ llvm-toolchain-jammy-$LLVM_VERSION main" >> /etc/apt/sources.list.d/llvm.list && \
40-
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key 2>/dev/null > /etc/apt/trusted.gpg.d/apt.llvm.org.asc && \
41-
wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null && \
42-
echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ jammy main' | tee /etc/apt/sources.list.d/kitware.list >/dev/null && \
43-
apt-get update && \
44-
apt-get install -y ${LLVM_PACKAGES} && \
45-
apt-get clean -y && \
46-
rm -rf /var/lib/apt/lists/* && \
47-
ln -s /usr/bin/lldb-17 /usr/bin/lldb
48-
4933
# Export tripy into the PYTHONPATH so it doesn't need to be installed after making changes
5034
ENV PYTHONPATH=/tripy

tripy/docs/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ which specifies doc metadata for each API (e.g. location).
4545
- Docstring must include *at least* **one [code example](#code-examples)**.
4646

4747
- If the function accepts `tp.Tensor`s, must indicate **data type constraints**
48-
with the [`wrappers.interface`](../nvtripy/wrappers.py) decorator.
48+
with the [`wrappers.interface`](../nvtripy/utils/wrappers.py) decorator.
4949

5050
**Example:**
5151

tripy/docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import nvtripy as tp
3030
from nvtripy.common.datatype import DATA_TYPES
31-
from nvtripy.wrappers import TYPE_VERIFICATION
31+
from nvtripy.utils.wrappers import TYPE_VERIFICATION
3232

3333
PARAM_PAT = re.compile(":param .*?:")
3434

tripy/docs/post0_developer_guides/how-to-add-new-ops.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ it as a `nvtripy.Module` under [`frontend/module`](source:/nvtripy/frontend/modu
176176

177177
```py
178178
# doc: no-eval
179-
from nvtripy import export, wrappers
179+
from nvtripy import export
180+
from nvtripy.utils import wrappers
180181
from nvtripy.types import ShapeLike
181182

182183
# We can use the `export.public_api()` decorator to automatically export this

tripy/nvtripy/backend/api/compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def process_arg(name, arg):
153153
return arg
154154

155155
new_args = []
156-
positional_arg_info, _ = utils.get_positional_arg_names(func, *args)
156+
positional_arg_info, _ = utils.utils.get_positional_arg_names(func, *args)
157157
for name, arg in positional_arg_info:
158158
new_args.append(process_arg(name, arg))
159159

@@ -165,7 +165,7 @@ def process_arg(name, arg):
165165
# as `InputInfo`s, but the order needs to match the signature of the original function.
166166
compiled_arg_names = [name for name in signature.parameters.keys() if name in input_names]
167167

168-
trace_outputs = utils.make_list(func(*new_args, **new_kwargs))
168+
trace_outputs = utils.utils.make_list(func(*new_args, **new_kwargs))
169169

170170
if not trace_outputs:
171171
raise_error(

tripy/nvtripy/backend/api/executable.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,17 @@
1414
# limitations under the License.
1515
import base64
1616
import inspect
17-
from typing import Sequence, Union, Tuple
17+
from dataclasses import dataclass
18+
from typing import Sequence, Tuple, Union
1819

1920
import mlir_tensorrt.runtime.api as runtime
20-
2121
from nvtripy import export
2222
from nvtripy.backend.mlir import Executor
2323
from nvtripy.backend.mlir import utils as mlir_utils
2424
from nvtripy.common.exception import raise_error
2525
from nvtripy.frontend import Tensor
26-
from nvtripy.function_registry import str_from_type_annotation
2726
from nvtripy.utils import json as json_utils
28-
from dataclasses import dataclass
27+
from nvtripy.utils.types import str_from_type_annotation
2928

3029

3130
# TODO(MLIR-TRT #923): Can generalize `InputInfo` and drop this class.

tripy/nvtripy/backend/mlir/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def compile_stabehlo_program(self, code: str) -> compiler.Executable:
7676
return compiler.compiler_stablehlo_to_executable(self.compiler_client, module.operation, opts)
7777

7878
# The optional flat_ir parameter is used to generate nicer error messages.
79-
@utils.log_time
79+
@utils.utils.log_time
8080
def compile(self, mlir_module: ir.Module, flat_ir: Optional["FlatIR"] = None) -> compiler.Executable:
8181
logger.mlir(lambda: f"{mlir_module.operation.get_asm(large_elements_limit=32)}\n")
8282
opts = self._make_mlir_opts(self.trt_builder_opt_level)

tripy/nvtripy/backend/mlir/executor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@
1818
from typing import List
1919

2020
import mlir_tensorrt.runtime.api as runtime
21-
2221
from nvtripy.backend.api.stream import default_stream
2322
from nvtripy.backend.mlir.memref import create_memref
2423
from nvtripy.backend.mlir.utils import MLIRRuntimeClient, convert_runtime_dtype_to_tripy_dtype
2524
from nvtripy.backend.utils import TensorInfo
2625
from nvtripy.common import datatype, device
2726
from nvtripy.common.exception import raise_error
2827
from nvtripy.common.utils import convert_list_to_array
29-
from nvtripy.utils import make_tuple
28+
from nvtripy.utils.utils import make_tuple
3029

3130

3231
class Executor:

tripy/nvtripy/backend/mlir/memref.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
# limitations under the License.
1616
#
1717

18-
import math
1918
import re
2019

2120
import mlir_tensorrt.runtime.api as runtime
22-
2321
from nvtripy.backend.mlir import utils as mlir_utils
2422
from nvtripy.common import device as tp_device
25-
from nvtripy.utils import raise_error
23+
from nvtripy.common.exception import raise_error
2624

2725
EMPTY_MEMREF_CACHE = {}
2826

tripy/nvtripy/backend/mlir/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def get_flat_ir_operation(output_names):
368368
return (
369369
[
370370
"This error occured while trying to compile the following FlatIR expression:",
371-
utils.code_pretty_str(str(op)),
371+
utils.utils.code_pretty_str(str(op)),
372372
"\n",
373373
]
374374
+ (

0 commit comments

Comments
 (0)