Speed up GroupNorm, improve normalization testing (#637)
Due to an API incompatibility between Torch GroupNorm and TensorRT
GroupNorm, the implementation uses InstanceNorm as a workaround
(the same workaround used by the ONNX parser and the Torch -> ONNX
converter).
The latest ONNX opset supports scale and bias with shape
`(num_channels,)`, so the TRT API will likely support this
eventually, at which point we can switch to the most direct
implementation.
Running 1k iterations in a loop shows the new implementation is roughly
17% faster on average. The nsys trace shows that the new implementation
is significantly better fused (only two computation kernels, for the
instancenorm and the affine transform), and the performance for a single
iteration of the module is up to 40% faster (30µs vs 50µs).
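The workaround described above (an instance norm followed by a per-channel affine transform) can be illustrated with a minimal pure-Python sketch. This is not the Tripy code from this commit: the function names, the nested-list layout (one sample of shape `[C][L]`), and the default `eps` are assumptions chosen for illustration. It only shows the general regrouping trick: fold each group of channels into a single "instance", normalize it, then restore the channel layout and apply `weight` and `bias` of shape `(num_channels,)`.

```python
import math


def instance_norm(rows, eps=1e-5):
    """Normalize each row independently to zero mean and unit variance."""
    out = []
    for row in rows:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append([(v - mean) * inv_std for v in row])
    return out


def group_norm_via_instance_norm(x, num_groups, weight, bias, eps=1e-5):
    """Illustrative GroupNorm for one sample x of shape [C][L] (nested lists).

    Hypothetical helper: the batch dimension is omitted because every
    sample in a batch is normalized independently.
    """
    num_channels = len(x)
    length = len(x[0])
    per_group = num_channels // num_groups

    # Step 1: flatten each group of channels into one long row.
    grouped = [
        [v for c in range(g * per_group, (g + 1) * per_group) for v in x[c]]
        for g in range(num_groups)
    ]

    # Step 2: instance norm without affine parameters, one row per group.
    normed = instance_norm(grouped, eps)

    # Step 3: restore the [C][L] layout and apply the per-channel affine.
    out = []
    for c in range(num_channels):
        g, offset = divmod(c, per_group)
        row = normed[g][offset * length:(offset + 1) * length]
        out.append([weight[c] * v + bias[c] for v in row])
    return out


sample = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
result = group_norm_via_instance_norm(sample, 2, [1.0] * 4, [0.0] * 4)
```

Keeping the affine step separate from the normalization mirrors the two computation kernels mentioned in the trace: one for the instance norm and one for the scale and bias.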
 where :math:`\bar{x}` is the mean and :math:`\sigma^2` is the variance.
+
+The input should have shape :math:`[N, C, D1, ...]` where :math:`N` is the batch size, :math:`C` is the number of channels, and :math:`D1, ...` are the feature dimensions.
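The shape contract in the docstring above can be sketched as a small stand-alone check. The helper name `check_norm_input_shape` and the channel-mismatch message are hypothetical; only the rank message is taken from the updated test. In the tests in this commit, a channel mismatch surfaces as a `tensorrt.slice` type error from the compiler rather than from a Python-level check like this one.

```python
def check_norm_input_shape(shape, num_channels):
    """Hypothetical validator for an input of shape [N, C, D1, ...]."""
    rank = len(shape)
    if rank < 3:
        # Message text mirrors the one matched in the updated test.
        raise ValueError(
            f"Input must have a rank of at least 3, but got input of rank: {rank}"
        )
    if shape[1] != num_channels:
        # Illustrative message only; not the error the library reports.
        raise ValueError(
            f"Expected {num_channels} channels at dimension 1, but got: {shape[1]}"
        )
    # Split into batch size, channel count, and feature dimensions.
    return shape[0], shape[1], tuple(shape[2:])


batch, channels, features = check_norm_input_shape((2, 3, 4, 4), num_channels=3)
```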
-            match=f"InstanceNorm input must have a rank of at least 3, but got input of rank: {x.rank}",
+            match=f"Input must have a rank of at least 3, but got input of rank: {x.rank}",
         ):
             tp_instancenorm(x).eval()
 
     def test_instancenorm_improper_channels(self):
         tp_instancenorm = tp.InstanceNorm(
             num_channels=3,
         )
-        tp_instancenorm.weight = tp.ones((3,))
-        tp_instancenorm.bias = tp.ones((3,))
+        tp_instancenorm.initialize_dummy_parameters()
 
         # dynamic shape
         x = tp.ones((2, 6, 4, 4))
         with helper.raises(
             tp.TripyException,
-            match="MTRTException: failed to run pass pipeline",
+            match=r"'tensorrt.slice' op inferred type\(s\) 'tensor\<2x6x4x4xf32\>' are incompatible with return type\(s\) of operation 'tensor\<\?x3x\?x\?xf32\>'",