
Commit afd4fb6

Speed up GroupNorm, improve normalization testing (#637)
Due to an API incompatibility between Torch GroupNorm and TensorRT GroupNorm, the implementation uses InstanceNorm as a workaround (the same WAR used by the ONNX parser and the Torch -> ONNX converter). The latest ONNX opset supports scale and bias with shape `(num_channels,)`, so the TRT API will likely support this eventually, at which point we can switch to the most direct implementation. Running 1k iterations in a loop shows the new implementation is roughly 17% faster on average. The nsys trace shows that the new implementation is significantly better fused (only two computation kernels: one for the InstanceNorm and one for the affine transform), and a single iteration of the module is up to 40% faster (30µs vs. 50µs).
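For reference (not part of this commit), a minimal standalone PyTorch sketch of the equivalence the workaround relies on: GroupNorm over C channels with G groups matches InstanceNorm over G "channels" once each group is folded into a single channel. The names, shapes, and eps below are illustrative assumptions.

# Standalone PyTorch sketch (not from this commit): GroupNorm(G, C) on an
# [N, C, H, W] input matches InstanceNorm over G "channels" after folding
# each group of C // G channels (plus spatial dims) into one channel.
import torch

N, C, G, H, W = 2, 6, 3, 4, 4
x = torch.randn(N, C, H, W)

gn = torch.nn.GroupNorm(G, C, eps=1e-5, affine=False)
inorm = torch.nn.InstanceNorm1d(G, eps=1e-5, affine=False)

# Fold each group into one "instance channel": [N, G, (C // G) * H * W]
grouped = x.reshape(N, G, (C // G) * H * W)
via_instance_norm = inorm(grouped).reshape(N, C, H, W)

assert torch.allclose(gn(x), via_instance_norm, atol=1e-5)

The new forward() does the same thing with nvtripy ops: split/stack to group the channels, InstanceNorm with unit scale and zero bias, flatten back, then the per-channel affine transform.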
1 parent e241e45 commit afd4fb6

File tree: 9 files changed (+280, -110 lines)

tripy/nvtripy/frontend/module/groupnorm.py

Lines changed: 38 additions & 22 deletions
@@ -24,6 +24,8 @@
 from nvtripy.frontend.module.parameter import DefaultParameter
 from nvtripy.frontend.tensor import Tensor
 
+from nvtripy.frontend.module.instancenorm import InstanceNorm
+
 
 @export.public_api(document_under="operations/modules")
 @dataclass
@@ -35,6 +37,8 @@ class GroupNorm(Module):
     :math:`\text{GroupNorm}(x) = \Large \frac{x - \bar{x}}{ \sqrt{\sigma^2 + \epsilon}} \normalsize * \gamma + \beta`
 
     where :math:`\bar{x}` is the mean and :math:`\sigma^2` is the variance.
+
+    The input should have shape :math:`[N, C, D1, ...]` where :math:`N` is the batch size, :math:`C` is the number of channels, and :math:`D1, ...` are the feature dimensions.
     """
 
     num_groups: int
@@ -68,30 +72,31 @@ def __init__(
         .. code-block:: python
            :linenos:
 
-            group_norm = tp.GroupNorm(2, 2)
+            group_norm = tp.GroupNorm(2, 4)
 
-            group_norm.weight = tp.iota(group_norm.weight.shape)
-            group_norm.bias = tp.iota(group_norm.bias.shape)
+            group_norm.weight = tp.ones(group_norm.weight.shape)
+            group_norm.bias = tp.zeros(group_norm.bias.shape)
 
-            input = tp.iota((1, 2, 2, 2), dim=1)
+            input = tp.iota((1, 4, 1, 1), dim=1)
             output = group_norm(input)
 
             np_out = cp.from_dlpack(output).get()  # doc: omit
-            assert np_out.shape == (1, 2, 2, 2)
+            assert np_out.shape == (1, 4, 1, 1)
 
             torch_tensor = torch.from_dlpack(input)  # doc: omit
             torch_gn = torch.nn.GroupNorm(2, 2).to(torch.device("cuda"))  # doc: omit
             torch_gn.weight.data = torch.from_dlpack(group_norm.weight)  # doc: omit
             torch_gn.bias.data = torch.from_dlpack(group_norm.bias)  # doc: omit
             torch_out = cp.from_dlpack(torch_gn(torch_tensor).detach()).get()  # doc: omit
-            assert np_out.shape == torch_out.shape  # doc: omit
-            assert np.allclose(np_out, torch_out)  # doc: omit
+            assert np_out.shape == torch_out.shape
+            assert np.allclose(np_out, torch_out)
        """
+
        super().__init__()
 
        if num_channels % num_groups:
            raise_error(
-                "Number of groups must divide number of channels evenly.",
+                "The number of groups must divide number of channels evenly.",
                details=[f"Got {num_groups} groups but {num_channels} channels."],
            )
 
@@ -112,19 +117,30 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor":
        Returns:
            A tensor of the same shape as the input.
        """
-        from nvtripy.frontend.ops.reduce.mean import mean
-        from nvtripy.frontend.ops.reduce.var import var
-        from nvtripy.frontend.ops.reshape import reshape
-        from nvtripy.frontend.ops.unary.rsqrt import rsqrt
-
-        input_shape = x.shape
 
-        x = reshape(x, (x.shape[0], self.num_groups, -1))
-        mean_val = mean(x, dim=-1, keepdim=True)
-        var_val = var(x, dim=-1, keepdim=True, correction=0) + self.eps
-        x = (x - mean_val) * rsqrt(var_val)
-        x = reshape(x, input_shape)
-
-        shape_to_broadcast = (1, self.num_channels) + (1,) * (x.rank - 2)
+        if x.rank < 3:
+            raise_error(
+                f"Input must have a rank of at least 3, but got input of rank: {x.rank}",
+                details=[
+                    "The input should have shape [N, C, D1, ...] where N is the batch size, C is the number of channels, and D1, ... are the feature dimensions."
+                ],
+            )
 
-        return reshape(self.weight, shape_to_broadcast) * x + reshape(self.bias, shape_to_broadcast)
+        from nvtripy.frontend.ops.reshape import reshape
+        from nvtripy.frontend.ops.split import split
+        from nvtripy.frontend.ops.stack import stack
+        from nvtripy.frontend.ops.flatten import flatten
+        from nvtripy.frontend.module.instancenorm import InstanceNorm
+        from nvtripy.frontend.ops.ones import ones
+        from nvtripy.frontend.ops.zeros import zeros
+
+        instance_norm = InstanceNorm(self.num_groups, dtype=self.dtype, eps=self.eps)
+        instance_norm.weight = ones((self.num_groups,), dtype=self.dtype)
+        instance_norm.bias = zeros((self.num_groups,), dtype=self.dtype)
+
+        # Use InstanceNorm as a WAR due to lack of TRT API compatibility for scale/bias with shape (num_channels, )
+        input_reshaped = stack(split(x, self.num_groups, 1), 1)
+        x = instance_norm(input_reshaped)
+        x = flatten(x, start_dim=1, end_dim=2)
+        broadcast_shape = (1, self.num_channels) + (1,) * (x.rank - 2)
+        return x * reshape(self.weight, broadcast_shape) + reshape(self.bias, broadcast_shape)
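As a side note, here is a hedged PyTorch analogue of the reshaping in the new forward() (torch.chunk stands in for nvtripy's split; the shapes are illustrative, not taken from the commit):

# Hedged PyTorch analogue of the grouping/flattening done in forward().
import torch

N, C, G, H, W = 1, 6, 3, 2, 2
x = torch.randn(N, C, H, W)

# stack(split(x, G, 1), 1): group the channels -> [N, G, C // G, H, W]
grouped = torch.stack(torch.chunk(x, G, dim=1), dim=1)
assert grouped.shape == (N, G, C // G, H, W)

# InstanceNorm over the G "channels" normalizes each group across (C // G, H, W).
normed = torch.nn.functional.instance_norm(grouped.reshape(N, G, -1)).reshape(grouped.shape)

# flatten(start_dim=1, end_dim=2): restore [N, C, H, W] before the per-channel affine transform.
restored = normed.flatten(start_dim=1, end_dim=2)
assert restored.shape == (N, C, H, W)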

tripy/nvtripy/frontend/module/instancenorm.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def instancenorm(
 
    if input_rank < 3:
        raise_error(
-            f"InstanceNorm input must have a rank of at least 3, but got input of rank: {input.rank}",
+            f"Input must have a rank of at least 3, but got input of rank: {input.rank}",
            details=[
                "Input is expected to have shape (N, C, D1, ...) where N is the batch size, C is the number of channels, and D1, ... are the spatial dimensions"
            ],

tripy/nvtripy/frontend/module/layernorm.py

Lines changed: 11 additions & 3 deletions
@@ -20,6 +20,7 @@
 
 from nvtripy import export, utils
 from nvtripy.common import datatype
+from nvtripy.common.exception import raise_error
 from nvtripy.frontend.module.module import Module
 from nvtripy.frontend.module.parameter import DefaultParameter
 from nvtripy.frontend.tensor import Tensor
@@ -44,10 +45,17 @@ def layernorm(
    D = len(normalized_shape)
    input_rank = input.rank
 
-    # Reshape weight and bias to match input rank for TensorRT normalization (expects [1, ...] + normalized_shape)
-    if input_rank > D:
-        from nvtripy.frontend.ops.reshape import reshape
+    if input_rank < 2:
+        raise_error(
+            f"Input must have a rank of at least 2, but got input of rank: {input.rank}",
+            details=[
+                "Input is expected to have shape (N, *) where N is the batch size, and * represents any number of channel dimension + spatial dimensions"
+            ],
+        )
+
+    from nvtripy.frontend.ops.reshape import reshape
 
+    if input_rank > D:
        broadcast_shape = (1,) * (input_rank - D) + normalized_shape
        weight = reshape(weight, broadcast_shape)
        bias = reshape(bias, broadcast_shape)

tripy/nvtripy/trace/ops/layernorm.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 @dataclass(repr=False)
 class LayerNorm(TraceOp):
    normalized_shape: Sequence[int]
-    eps: float = 1e-5
+    eps: float
 
    infer_rank = op_utils.InferRankPolicies.same_as_input()

tripy/tests/frontend/module/test_instancenorm.py

Lines changed: 4 additions & 6 deletions
@@ -24,28 +24,26 @@ def test_instancenorm_improper_rank(self):
        tp_instancenorm = tp.InstanceNorm(
            num_channels=3,
        )
-        tp_instancenorm.weight = tp.ones((3,))
-        tp_instancenorm.bias = tp.ones((3,))
+        tp_instancenorm.initialize_dummy_parameters()
 
        x = tp.ones((2, 3))
        with helper.raises(
            tp.TripyException,
-            match=f"InstanceNorm input must have a rank of at least 3, but got input of rank: {x.rank}",
+            match=f"Input must have a rank of at least 3, but got input of rank: {x.rank}",
        ):
            tp_instancenorm(x).eval()
 
    def test_instancenorm_improper_channels(self):
        tp_instancenorm = tp.InstanceNorm(
            num_channels=3,
        )
-        tp_instancenorm.weight = tp.ones((3,))
-        tp_instancenorm.bias = tp.ones((3,))
+        tp_instancenorm.initialize_dummy_parameters()
 
        # dynamic shape
        x = tp.ones((2, 6, 4, 4))
        with helper.raises(
            tp.TripyException,
-            match="MTRTException: failed to run pass pipeline",
+            match=r"'tensorrt.slice' op inferred type\(s\) 'tensor\<2x6x4x4xf32\>' are incompatible with return type\(s\) of operation 'tensor\<\?x3x\?x\?xf32\>'",
        ):
            tp_instancenorm(x).eval()

tripy/tests/frontend/module/test_layernorm.py

Lines changed: 14 additions & 2 deletions
@@ -23,11 +23,23 @@ def test_layernorm_improper_dimensions(self):
        tp_layernorm = tp.LayerNorm(
            normalized_shape=[2, 2],
        )
-        tp_layernorm.weight = tp.ones((2, 2))
-        tp_layernorm.bias = tp.ones((2, 2))
+        tp_layernorm.initialize_dummy_parameters()
 
        x = tp.ones((5, 5, 5))
        with helper.raises(
            tp.TripyException, match="The normalization scale is not broadcast-compatible with the input at dimension 1"
        ):
            tp_layernorm(x).eval()
+
+    def test_layernorm_improper_rank(self):
+        tp_layernorm = tp.LayerNorm(
+            normalized_shape=[2],
+        )
+        tp_layernorm.initialize_dummy_parameters()
+
+        x = tp.ones((2,))
+        with helper.raises(
+            tp.TripyException,
+            match=f"Input must have a rank of at least 2, but got input of rank: {x.rank}",
+        ):
+            tp_layernorm(x).eval()

tripy/tests/integration/test_groupnorm.py

Lines changed: 76 additions & 26 deletions
@@ -15,45 +15,95 @@
 # limitations under the License.
 #
 
+import nvtripy as tp
 import pytest
 import torch
 
-import nvtripy as tp
+from tests.helper import TORCH_DTYPES
+
+DTYPES = [tp.float16, tp.float32]
+
+dtype_params = pytest.mark.parametrize("dtype", DTYPES)
+input_shape_params = pytest.mark.parametrize("input_shape", [(1, 6, 2, 2)])
+num_groups_params = pytest.mark.parametrize("num_groups", [2, 3])
+num_channels_params = pytest.mark.parametrize("num_channels", [6])
+
 
-DTYPES = [(torch.float16, tp.float16), (torch.float32, tp.float32)]
+@pytest.fixture
+def setup(dtype, input_shape, num_groups, num_channels):
+    eps = 0.0
+    torch_dtype = TORCH_DTYPES[dtype]
+    groupnorm = torch.nn.GroupNorm(
+        num_groups=num_groups,
+        num_channels=num_channels,
+        eps=eps,
+        dtype=torch_dtype,
+        device="cuda",
+    )
+    tp_groupnorm = tp.GroupNorm(
+        num_groups=num_groups,
+        num_channels=num_channels,
+        eps=eps,
+        dtype=dtype,
+    )
+
+    input = torch.empty(*input_shape, dtype=torch_dtype, device="cuda").uniform_(0, 10)
+    tp_input = tp.Tensor(input, dtype=dtype)
+    yield groupnorm, tp_groupnorm, tp_input
 
 
 class TestGroupNorm:
 
-    @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
-    @pytest.mark.parametrize("input_shape", [(1, 10, 2)])
-    @pytest.mark.parametrize("num_groups", [2, 5])
-    @pytest.mark.parametrize("num_channels", [10])
-    def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups, num_channels, eager_or_compiled):
-        eps = 1e-5
-        groupnorm = torch.nn.GroupNorm(
-            num_groups=num_groups,
-            num_channels=num_channels,
-            eps=eps,
-            dtype=torch_dtype,
-            device="cuda",
-        )
-        tp_groupnorm = tp.GroupNorm(
-            num_groups=num_groups,
-            num_channels=num_channels,
-            eps=eps,
-            dtype=tp_dtype,
-        )
+    @dtype_params
+    @input_shape_params
+    @num_groups_params
+    @num_channels_params
+    def test_groupnorm_normalization(self, input_shape, num_groups, setup, eager_or_compiled):
+        """Test that normalized output has approximately mean=0, std=1"""
+        _, tp_groupnorm, tp_input = setup
+        dtype = tp_groupnorm.weight.dtype
+
+        tp_groupnorm.weight = tp.ones(tp_groupnorm.weight.shape, dtype=dtype)
+        tp_groupnorm.bias = tp.zeros(tp_groupnorm.bias.shape, dtype=dtype)
+
+        output = eager_or_compiled(tp_groupnorm, tp_input)
+        output_torch = torch.from_dlpack(output)
+
+        N, C = input_shape[0], input_shape[1]
+        spatial_size = torch.prod(torch.tensor(input_shape[2:]))
+        reshaped = output_torch.view(N, num_groups, C // num_groups, spatial_size)
+
+        means = reshaped.mean(dim=(2, 3))
+        vars = reshaped.var(dim=(2, 3), unbiased=False)
+
+        mean_abs = means.abs().mean().item()
+        var_diff = (vars - 1).abs().mean().item()
+
+        assert mean_abs < 2e-4, f"Group mean should be close to 0, got {mean_abs}"
+        assert var_diff < 1e-3, f"Group variance should be close to 1, got {var_diff}"
+
+    @dtype_params
+    @input_shape_params
+    @num_groups_params
+    @num_channels_params
+    def test_groupnorm_affine_transformation(self, setup, eager_or_compiled):
+        """Test the GroupNorm with affine transformation included"""
+        groupnorm, tp_groupnorm, tp_input = setup
+        dtype = tp_groupnorm.weight.dtype
+        input = torch.from_dlpack(tp_input)
+
+        torch.nn.init.uniform_(groupnorm.weight, 0.2, 2)
+        torch.nn.init.uniform_(groupnorm.bias, 0.2, 2)
 
        tp_groupnorm.weight = tp.Tensor(groupnorm.weight.to("cpu").detach())
        tp_groupnorm.bias = tp.Tensor(groupnorm.bias.to("cpu").detach())
 
-        input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype).to("cuda")
-        tp_input = tp.Tensor(input, dtype=tp_dtype)
-
        output = eager_or_compiled(tp_groupnorm, tp_input)
        with torch.no_grad():
            expected = groupnorm(input)
 
-        rtol_ = 2e-6 if tp_dtype == tp.float32 else 1e-3
-        assert torch.allclose(torch.from_dlpack(output), expected, rtol=rtol_)
+        atol_ = 1e-6 if dtype == tp.float32 else 5e-3
+
+        torch_output = torch.from_dlpack(output)
+        assert torch_output.shape == expected.shape
+        assert torch.allclose(torch_output, expected, atol=atol_)
