Skip to content

Commit b5ab1ac

Browse files
committed
tests: adding a testcase for save api + ts and some type hints
1 parent b9d6a4c commit b5ab1ac

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import platform
66
from enum import Enum
7-
from typing import Any, Callable, List, Optional, Sequence, Set, Union
7+
from typing import Any, Callable, List, Optional, Sequence, Set, Union, Literal
88

99
import torch
1010
import torch.fx
@@ -580,7 +580,9 @@ def save(
580580
module: Any,
581581
file_path: str = "",
582582
*,
583-
output_format: str = "exported_program",
583+
output_format: Literal[
584+
"exported_program", "torchscript", "aot_inductor"
585+
] = "exported_program",
584586
inputs: Optional[Sequence[torch.Tensor]] = None,
585587
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
586588
kwarg_inputs: Optional[dict[str, Any]] = None,

tests/py/ts/api/test_export_serde.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import importlib
2+
import os
3+
import platform
4+
import tempfile
5+
import unittest
6+
7+
import pytest
8+
import torch
9+
import torch_tensorrt as torchtrt
10+
from torch_tensorrt.dynamo.utils import (
11+
COSINE_THRESHOLD,
12+
cosine_similarity,
13+
get_model_device,
14+
)
15+
16+
assertions = unittest.TestCase()
17+
18+
@pytest.mark.unit
19+
def test_save_load_ts(ir):
20+
"""
21+
This tests save/load API on Torchscript format (model still compiled using ts workflow)
22+
"""
23+
24+
class MyModule(torch.nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
28+
self.relu = torch.nn.ReLU()
29+
30+
def forward(self, x):
31+
conv = self.conv(x)
32+
relu = self.relu(conv)
33+
mul = relu * 0.5
34+
return mul
35+
36+
model = MyModule().eval().cuda()
37+
input = torch.randn((1, 3, 224, 224)).to("cuda")
38+
39+
trt_gm = torchtrt.compile(
40+
model,
41+
ir="ts",
42+
inputs=[input],
43+
min_block_size=1,
44+
cache_built_engines=False,
45+
reuse_cached_engines=False,
46+
)
47+
outputs_trt = trt_gm(input)
48+
# Save it as torchscript representation
49+
torchtrt.save(trt_gm, "./trt.ts", output_format="torchscript", inputs=[input])
50+
51+
trt_ts_module = torchtrt.load("./trt.ts")
52+
outputs_trt_deser = trt_ts_module(input)
53+
54+
cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)
55+
assertions.assertTrue(
56+
cos_sim > COSINE_THRESHOLD,
57+
msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
58+
)

0 commit comments

Comments
 (0)