1 change: 1 addition & 0 deletions requirements.txt
@@ -0,0 +1 @@
termcolor
4 changes: 2 additions & 2 deletions scripts/build_contrib.sh
@@ -13,11 +13,11 @@ pushd /tmp/TensorRT
git sparse-checkout set /tools/pytorch-quantization/
git apply --reject --whitespace=fix pytorch_nvidia_quantization.patch
cd tools/pytorch-quantization/
python setup.py install
sudo python3 setup.py install
popd

pushd $parentdir
python3 setup.py install --plugins --contrib
sudo python3 setup.py install --plugins --contrib
popd


5 changes: 5 additions & 0 deletions setup.py
@@ -5,6 +5,10 @@
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
from packaging import version

REQUIREMENTS_PATH = 'requirements.txt'

with open(REQUIREMENTS_PATH, 'r') as file:
    required_libraries = file.read().splitlines()

def trt_inc_dir():
return "/usr/include/aarch64-linux-gnu"
@@ -55,5 +59,6 @@ def trt_lib_dir():
    packages=find_packages(exclude=exclude_dir),
    ext_package='torch2trt',
    ext_modules=ext_modules,
    install_requires=required_libraries,
    cmdclass={'build_ext': BuildExtension}
)
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
@@ -69,3 +69,4 @@
from .transpose import *
from .unary import *
from .view import *
from .zeros import *
65 changes: 60 additions & 5 deletions torch2trt/converters/avg_pool.py
@@ -2,6 +2,55 @@
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.nn.functional.avg_pool1d')
def convert_avg_pool1d(ctx):
    # At the time of this implementation, TensorRT 8.x does not yet support 1D average pooling via `add_pooling_nd(...)`.
    # As a workaround, we unsqueeze an extra dimension onto the input (transforming it from
    # (N, C, L) to (N, C, L, 1)) so that 2D average pooling can be applied over the last two dimensions.

    input = get_arg(ctx, 'input', pos=0, default=None)
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return

    kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None)
    stride = get_arg(ctx, 'stride', pos=2, default=None)
    padding = get_arg(ctx, 'padding', pos=3, default=0)
    ceil_mode = get_arg(ctx, 'ceil_mode', pos=4, default=False)
    count_include_pad = get_arg(ctx, 'count_include_pad', pos=5, default=True)

    # Convert the 1D arguments to their 2D equivalents, since the input is always 1D.
    kernel_size = (kernel_size, 1)
    stride = kernel_size if not stride else (stride, 1)
    padding = (padding, 0)

    # Shuffle layer to unsqueeze an extra dimension for 2D average pooling.
    unsqueeze_layer = ctx.network.add_shuffle(input_trt)
    set_layer_precision(ctx, unsqueeze_layer)
    unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
    unsqueeze_trt = unsqueeze_layer.get_output(0)

    # Use 2D average pooling to emulate 1D average pooling.
    layer = ctx.network.add_pooling_nd(
        input=unsqueeze_trt,
        type=trt.PoolingType.AVERAGE,
        window_size=kernel_size,
    )
    set_layer_precision(ctx, layer)
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.average_count_excludes_padding = not count_include_pad

    if ceil_mode:
        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    pooling_trt = layer.get_output(0)

    # Shuffle layer to squeeze out the dimension that was added for 2D average pooling, so the output is 1D again.
    squeeze_layer = ctx.network.add_shuffle(pooling_trt)
    set_layer_precision(ctx, squeeze_layer)
    squeeze_layer.reshape_dims = tuple(pooling_trt.shape[:-1])
    output._trt = squeeze_layer.get_output(0)

@tensorrt_converter("torch.nn.functional.avg_pool2d", enabled=trt_version() < '7.0')
def convert_avg_pool2d(ctx):
    # parse args
@@ -83,12 +132,14 @@ def convert_avg_pool_trt7(ctx):
        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 5, 7)])
def test_avg_pool2d_without_ceil_mode():
    return torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
    return torch.nn.AvgPool2d(
        kernel_size=3, stride=2, padding=1, ceil_mode=False
    )


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6)])
@@ -102,10 +153,14 @@ def test_avg_pool2d_with_ceil_mode():
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 4, 6)], enabled=trt_version() >= '7.0')
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 5, 7)], enabled=trt_version() >= '7.0')
def test_avg_pool3d_without_ceil_mode_trt7():
    return torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
    return torch.nn.AvgPool3d(
        kernel_size=3, stride=2, padding=1, ceil_mode=False
    )


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 4, 6)], enabled=trt_version() >= '7.0')
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 5, 7)], enabled=trt_version() >= '7.0')
def test_avg_pool3d_with_ceil_mode_trt7():
    return torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True, count_include_pad=False) # TRT does not support ceil_mode=True && count_include_pad=True
    return torch.nn.AvgPool3d(
        kernel_size=3, stride=2, padding=1, ceil_mode=True, count_include_pad=False
    ) # TRT does not support ceil_mode=True && count_include_pad=True
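

# Hypothetical addition, not part of this patch: a module test for the new 1D converter, written in the
# same style as the 2D/3D tests above. It assumes torch.nn.AvgPool1d dispatches to
# torch.nn.functional.avg_pool1d, which is what convert_avg_pool1d is registered against.
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 8)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 7)])
def test_avg_pool1d_without_ceil_mode():
    return torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1, ceil_mode=False)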
2 changes: 1 addition & 1 deletion torch2trt/converters/expand.py
@@ -2,10 +2,10 @@
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.Tensor.expand_as')
@tensorrt_converter('torch.Tensor.expand')
def convert_expand(ctx):
    input = ctx.method_args[0]
    sizes = ctx.method_args[1:]
    output = ctx.method_return

    inshape = tuple(input.shape)[1:] # exclude batch
66 changes: 66 additions & 0 deletions torch2trt/converters/zeros.py
@@ -0,0 +1,66 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def _set_layer_precision(ctx, layer):
    # Supported TRT precisions as given by torch2trt_kwargs.
    INT8_MODE = "int8_mode"
    FP16_MODE = "fp16_mode"

    # Check that args exist as expected in torch2trt_kwargs.
    trt_kwargs = ctx.torch2trt_kwargs
    assert INT8_MODE in trt_kwargs
    assert FP16_MODE in trt_kwargs

    is_int8 = trt_kwargs.get(INT8_MODE, False)
    is_fp16 = trt_kwargs.get(FP16_MODE, False)

    if is_int8:
        layer.precision = trt.int8
        layer.set_output_type(0, trt.int8)
    elif is_fp16:
        layer.precision = trt.float16
        layer.set_output_type(0, trt.float16)


@tensorrt_converter('torch.zeros')
def convert_zeros(ctx):
    tensor = ctx.method_return

    # Implementation copied from add_trt_constant.
    shape = tuple(tensor.shape[1:])
    array = tensor[0].detach().cpu().numpy()
    layer = ctx.network.add_constant(shape, array)

    _set_layer_precision(ctx, layer)

    tensor._trt = layer.get_output(0)


class Zeros(torch.nn.Module):
    def __init__(self, *size):
        super().__init__()
        self.size = size

    def forward(self, x):
        return x + torch.zeros(*self.size, device=torch.device('cuda'))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)])
def test_zeros():
    return Zeros((1, 2, 3, 4))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)])
def test_zeros_var_args():
    return Zeros(1, 2, 3, 4)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)], fp16_mode=True)
def test_zeros_fp16_mode():
    return Zeros(1, 2, 3, 4)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)], int8_mode=True)
def test_zeros_int8_mode():
    return Zeros(1, 2, 3, 4)
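

# Hypothetical usage sketch, not part of this patch: the fp16_mode/int8_mode keyword arguments passed to
# torch2trt() are what ctx.torch2trt_kwargs exposes to _set_layer_precision above.
if __name__ == '__main__':
    from torch2trt import torch2trt

    model = Zeros(1, 2, 3, 4).cuda().eval()
    x = torch.ones(1, 2, 3, 4).cuda()
    model_trt = torch2trt(model, [x], fp16_mode=True)
    print(torch.max(torch.abs(model(x) - model_trt(x))))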
4 changes: 0 additions & 4 deletions torch2trt/module_test.py
@@ -1,7 +1,3 @@
import torch
import torchvision


class ModuleTest(object):
    def __init__(self, module_fn, dtype, device, input_shapes, **torch2trt_kwargs):
        self.module_fn = module_fn