Skip to content

Commit b832361

Browse files
pranavm-nvidia, jhalakpatel, yizhuoz004
authored
Migrates Tripy to use the tensorrt dialect instead of StableHLO (#607)
Migrates Tripy to use the `tensorrt` dialect instead of StableHLO. The latter is a lower level representation than the TensorRT network API, meaning that we first break down high level Tripy operations and then pattern match them back up. Using the `tensorrt` dialect allows us to go directly to TensorRT without so many intervening layers of translation, which reduces complexity and bugs, and improves performance. Broadly, this change does the following: - Removes FlatIR - Makes Trace responsible for mapping to MLIR - Trace operations represent MLIR operations 1:1 - Makes the frontend API responsible for composing Trace operations --------- Signed-off-by: yizhuoz004 <[email protected]> Co-authored-by: Jhalak Patel <[email protected]> Co-authored-by: Yizhuo Zhang <[email protected]>
1 parent 6a9cfa9 commit b832361

File tree

353 files changed

+8030
-11910
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

353 files changed

+8030
-11910
lines changed

tripy/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,3 +88,6 @@ core*
8888

8989
# PyTest profiling results
9090
*.prof
91+
92+
# PyTorch checkpoints
93+
*.pt.*

tripy/CONTRIBUTING.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ pre-commit install
7070
7171
We've written developer guides to help you understand the codebase:
7272

73+
<!-- TODO (pranavm): Update links here -->
7374
- Start with the
7475
[architecture](https://nvidia.github.io/TensorRT-Incubator/post0_developer_guides/architecture.html)
7576
documentation.

tripy/README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ guide for details:
4545
def __init__(self):
4646
self.conv = tp.Conv(in_channels=1, out_channels=1, kernel_dims=[3, 3])
4747

48-
def __call__(self, x):
48+
def forward(self, x):
4949
x = self.conv(x)
5050
x = tp.relu(x)
5151
return x

tripy/docs/README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ Guides are **markdown** files parsed by the Myst parser, therefore:
101101
- **Example:**
102102

103103
```md
104-
[Fill operation](source:/nvtripy/trace/ops/fill.py)
104+
[Broadcast operation](source:/nvtripy/trace/ops/broadcast.py)
105105
```
106106

107107
- **Why:** Other links will cause the file to be downloaded instead of linking to the repository.

tripy/docs/conf.py

Lines changed: 7 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#
2-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,8 +27,6 @@
2727
from tests import helper
2828

2929
import nvtripy as tp
30-
from nvtripy.common.datatype import DATA_TYPES
31-
from nvtripy.utils.wrappers import TYPE_VERIFICATION
3230

3331
PARAM_PAT = re.compile(":param .*?:")
3432

@@ -179,10 +177,10 @@ def process_docstring_impl(app, what, name, obj, options, lines):
179177
pname = "*" + pname
180178

181179
# Type annotations are optional for the `self` parameter unless the API has to be type-verified.
182-
if pname != "self" or name in TYPE_VERIFICATION:
180+
if pname != "self":
183181
assert (
184182
pname in documented_args
185-
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args} {doc}"
183+
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args}"
186184
assert (
187185
pname in documented_args
188186
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args}"
@@ -206,69 +204,12 @@ def process_docstring_impl(app, what, name, obj, options, lines):
206204
":returns:" in doc
207205
), f"For: {obj}, return value is not documented. Please ensure you've included a `Returns:` section"
208206

209-
if name in TYPE_VERIFICATION:
210-
add_text_index = -1
211-
for index, block in enumerate(blocks):
212-
213-
def insert_block(text):
214-
nonlocal index
215-
216-
blocks.insert(index, text)
217-
index += 1
218-
219-
if re.search(r".. code-block::", block):
220-
type_dict = TYPE_VERIFICATION[name].dtypes
221-
insert_block("TYPE CONSTRAINTS:")
222-
# Add the dtype constraint name and the dtypes that correlate.
223-
for type_name, dt in type_dict.items():
224-
insert_block(
225-
f" - **{type_name}**: :class:`"
226-
+ "`, :class:`".join(
227-
sorted(
228-
set(dt),
229-
key=lambda dtype: (
230-
tuple(typ.__name__ for typ in DATA_TYPES[dtype].__bases__),
231-
DATA_TYPES[dtype].itemsize,
232-
),
233-
)
234-
)
235-
+ "`",
236-
)
237-
insert_block("\n")
238-
239-
if TYPE_VERIFICATION[name].exceptions:
240-
# Add the dtype exceptions.
241-
insert_block("UNSUPPORTED TYPE COMBINATIONS:")
242-
for exception_dict in TYPE_VERIFICATION[name].exceptions:
243-
insert_block(
244-
" - "
245-
+ ", ".join([f"**{key}**\ =\ :class:`{val}`" for key, val in exception_dict.items()]),
246-
)
247-
insert_block("\n")
248-
break
249-
250-
if re.search(r":param \w+: ", block):
251-
param_name = re.match(r":param (\w+): ", block).group(1)
252-
# Add dtype constraint to start of each parameter description.
253-
if TYPE_VERIFICATION[name].constraints.get(param_name, None):
254-
add_text_index = re.search(r":param \w+: ", block).span()[1]
255-
blocks[index] = (
256-
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[name].constraints[param_name]}**\ ] {block[add_text_index:]}"
257-
)
258-
259-
if TYPE_VERIFICATION[name].return_dtype is not None and re.search(r":returns:", block):
260-
add_text_index = re.search(r":returns:", block).span()[1] + 1
261-
# Add dtype constraint to start of returns description.
262-
blocks[index] = (
263-
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[name].return_dtype}**\ ] {block[add_text_index:]}"
264-
)
265-
266207
seen_classes.add(name)
267208

268209
def allow_no_example():
269-
# `tp.Module`s include examples in their constructors, so their __call__ methods don't require examples.
210+
# `tp.Module`s include examples in their constructors, so their forward methods don't require examples.
270211
is_tripy_module_call_method = False
271-
if what == "method" and obj.__name__ == "__call__":
212+
if what == "method" and obj.__name__ == "forward":
272213
class_name = "nvtripy." + name.rpartition(".")[0]
273214
# Class names are prefixed with nvtripy.<...>, so we need to import it here to make eval() work.
274215
import nvtripy
@@ -336,7 +277,8 @@ def process_docstring(app, what, name, obj, options, lines):
336277
try:
337278
process_docstring_impl(app, what, name, obj, options, lines)
338279
except:
339-
print(f"Error while processing {what}: {name} ({obj})")
280+
sep = "\n"
281+
print(f"Error while processing {what}: {name} ({obj}).\nNote: Docstring was: {sep.join(lines)}")
340282
raise
341283

342284

tripy/docs/generate_rsts.py

Lines changed: 56 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,16 +19,15 @@
1919
import inspect
2020
import os
2121
import re
22-
import shutil
2322
import subprocess as sp
2423
from collections import defaultdict
2524
from dataclasses import dataclass
2625
from textwrap import dedent, indent
2726
from typing import Dict, List, Set
2827

2928
import nvtripy as tp
30-
from tests import helper
3129
from nvtripy.export import PUBLIC_APIS
30+
from tests import helper
3231

3332

3433
@dataclass
@@ -93,7 +92,7 @@ def build_api_doc(api, include_heading=True):
9392
)
9493

9594

96-
def build_markdown_doc(source_path):
95+
def build_rst_markdown_include(source_path):
9796
return dedent(
9897
f"""
9998
.. include:: {source_path}
@@ -181,7 +180,24 @@ def build_index_file(name, constituents, include_heading=True, caption=None):
181180
)
182181

183182

183+
def write_if_different(filepath, content):
184+
if os.path.exists(filepath):
185+
with open(filepath, "r") as f:
186+
existing_content = f.read()
187+
if existing_content == content:
188+
print(f"No changes for {filepath}")
189+
# Update the file timestamp so we know its up-to-date in timestamp checks.
190+
os.utime(filepath, None)
191+
return
192+
with open(filepath, "w") as f:
193+
f.write(content)
194+
print(f"Updated {filepath}")
195+
196+
184197
def process_guide(guide_path: str, processed_guide_path: str):
198+
if os.path.exists(processed_guide_path) and os.path.getmtime(guide_path) <= os.path.getmtime(processed_guide_path):
199+
print(f"Skipping {guide_path} as it is not newer than {processed_guide_path}")
200+
return
185201

186202
os.makedirs(os.path.dirname(processed_guide_path), exist_ok=True)
187203

@@ -194,7 +210,7 @@ def process_guide(guide_path: str, processed_guide_path: str):
194210
for index, block in enumerate(blocks):
195211
print(f"Processing block {index} (lang={block.lang}) in: {guide_path}: ", end="")
196212

197-
should_eval = not block.has_marker("doc: no_eval")
213+
should_eval = not block.has_marker("doc: no_eval_or_format")
198214
if should_eval and block.lang.startswith("py"):
199215
print("Evaluating Python block")
200216

@@ -229,6 +245,7 @@ def add_block(kind, contents, lang, code_indentation):
229245
format_contents=add_block,
230246
err_msg=f"Error while executing code block {index} (line {block.line_number}) from {guide_path}. ",
231247
local_vars=code_locals,
248+
force_no_print_locals=block.has_marker("doc: no_print_locals"),
232249
)
233250
)
234251

@@ -254,8 +271,7 @@ def add_block(kind, contents, lang, code_indentation):
254271
else:
255272
print("Omitting block")
256273

257-
with open(processed_guide_path, "w") as fout:
258-
fout.write("\n".join(new_blocks))
274+
write_if_different(processed_guide_path, "\n".join(new_blocks))
259275

260276

261277
def main():
@@ -272,8 +288,6 @@ def main():
272288

273289
args = parser.parse_args()
274290

275-
shutil.rmtree(args.output, ignore_errors=True)
276-
277291
def make_output_path(*components):
278292
return os.path.join(args.output, *components)
279293

@@ -282,6 +296,7 @@ def make_output_path(*components):
282296
def is_file(document_under):
283297
return bool(os.path.splitext(document_under)[1])
284298

299+
rst_files = {}
285300
seen_apis = set()
286301
for api in PUBLIC_APIS:
287302
name = get_name(api)
@@ -314,9 +329,12 @@ def is_file(document_under):
314329

315330
os.makedirs(os.path.dirname(rst_path), exist_ok=True)
316331

317-
pre_existing_file = os.path.exists(rst_path)
318-
with open(rst_path, "a") as f:
319-
f.write(build_api_doc(api, include_heading=not pre_existing_file))
332+
# Only include a heading if we didn't already create the file:
333+
include_heading = False
334+
if rst_path not in rst_files:
335+
include_heading = True
336+
rst_files[rst_path] = ""
337+
rst_files[rst_path] += build_api_doc(api, include_heading=include_heading)
320338

321339
def str_from_hierarchy(obj):
322340
if isinstance(obj, dict):
@@ -374,13 +392,11 @@ def str_from_hierarchy(obj):
374392

375393
# Do not create RSTs for top-level README
376394
if not is_top_level_dir:
377-
with open(guide_out_path, "w") as f:
378-
# Grab the path of the .md file relative to the directory containing the .rst file so we can include it correctly.
379-
f.write(
380-
build_markdown_doc(
381-
source_path=os.path.relpath(processed_guide, os.path.dirname(guide_out_path))
382-
)
383-
)
395+
content = build_rst_markdown_include(
396+
source_path=os.path.relpath(processed_guide, os.path.dirname(guide_out_path))
397+
)
398+
rst_files[guide_out_path] = content
399+
384400
guides.append(os.path.join(basename, os.path.splitext(guide_filename)[0]))
385401

386402
# We have special treatment for the files in the top-level directory.
@@ -391,24 +407,33 @@ def str_from_hierarchy(obj):
391407
is_root = path == ""
392408

393409
index_path = make_output_path(path, "index.rst")
394-
pre_existing_file = os.path.exists(index_path)
410+
pre_existing_file = index_path in rst_files
395411

396412
if is_root:
397413
assert (
398414
not pre_existing_file
399415
), f"APIs should *not* target the root index file directly! Please remove any `document_under='index.rst'` arguments!"
400-
401-
with open(index_path, "a") as f:
402-
f.write(
403-
build_root_index_file(constituents, guide_sets, processed_markdown_dirname)
404-
if is_root
405-
else build_index_file(
406-
name=to_title(os.path.basename(os.path.dirname(index_path))),
407-
constituents=constituents,
408-
include_heading=not pre_existing_file,
409-
caption="See also:" if pre_existing_file else None,
410-
)
416+
content = build_root_index_file(constituents, guide_sets, processed_markdown_dirname)
417+
else:
418+
content = build_index_file(
419+
name=to_title(os.path.basename(os.path.dirname(index_path))),
420+
constituents=constituents,
421+
include_heading=not pre_existing_file,
422+
caption="See also:" if pre_existing_file else None,
411423
)
424+
if not pre_existing_file:
425+
rst_files[index_path] = ""
426+
rst_files[index_path] += content
427+
428+
# Delete any existing RST files that are not in rst_files
429+
for existing_file in glob.glob(os.path.join(args.output, "**", "*.rst"), recursive=True):
430+
if existing_file not in rst_files:
431+
print(f"Removing stale file: {existing_file}")
432+
os.remove(existing_file)
433+
434+
# Write the new/updated RST files
435+
for path, content in rst_files.items():
436+
write_if_different(path, content)
412437

413438

414439
if __name__ == "__main__":

0 commit comments

Comments (0)