
Commit 3cab972

Update LayerNorm2D to use TRT normalization API (#650)
Naive performance test results indicate the new implementation is about 50% faster:

```
LayerNorm2d_TP1 (Direct Implementation) took 0.0059s for 100 iterations
LayerNorm2d_TP2 (LayerNorm-based) took 0.0038s for 100 iterations
```

Accuracy (predicted mask scores) for the SamV2 demo remains the same.
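The timing harness behind the `LayerNorm2d_TP1` / `LayerNorm2d_TP2` numbers is not part of this commit; those figures come from Tripy/TensorRT execution. Purely as an illustration of what a "naive" 100-iteration comparison of the two strategies looks like (hand-rolled channel-wise normalization vs. a permute-wrapped last-dim LayerNorm, as in the diff below), here is a standalone sketch in PyTorch. The class names, tensor shapes, and printed labels are made up for the example and will not reproduce the quoted numbers.

```python
# Illustrative only: a naive timing loop comparing the two LayerNorm2d
# strategies changed in this commit, sketched in PyTorch. Not the harness
# used for the numbers quoted above.
import time
import torch


class ManualLayerNorm2d(torch.nn.Module):
    """Hand-rolled per-channel LayerNorm on NCHW input (old approach)."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(num_channels))
        self.bias = torch.nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(dim=1, keepdim=True)
        s = ((x - u) ** 2).mean(dim=1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class WrappedLayerNorm2d(torch.nn.LayerNorm):
    """Permute NCHW -> NHWC, apply a standard LayerNorm, permute back (new approach)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


def time_module(module: torch.nn.Module, x: torch.Tensor, iters: int = 100) -> float:
    # Naive wall-clock timing of `iters` forward passes.
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        return time.perf_counter() - start


x = torch.randn(1, 256, 64, 64)
print(f"manual : {time_module(ManualLayerNorm2d(256), x):.4f}s for 100 iterations")
print(f"wrapped: {time_module(WrappedLayerNorm2d(256), x):.4f}s for 100 iterations")
```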
1 parent 3ae73ea commit 3cab972

1 file changed: +5 -17 lines changed

tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py

Lines changed: 5 additions & 17 deletions
```diff
@@ -180,26 +180,14 @@ def forward(self, x):
         return x
 
 
-class LayerNorm2d(tp.Module):
+class LayerNorm2d(tp.LayerNorm):
     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
-        super().__init__()
-        from nvtripy.frontend.module.parameter import DefaultParameter
-
-        self.weight = DefaultParameter((num_channels,), tp.float32)
-        self.bias = DefaultParameter((num_channels,), tp.float32)
-        self.eps = eps
+        super().__init__(num_channels, dtype=tp.float32, eps=eps)
 
     def forward(self, x: tp.Tensor) -> tp.Tensor:
-        original_dtype = x.dtype
-        x = tp.cast(x, tp.float32)
-        u = tp.mean(x, dim=1, keepdim=True)
-        s = tp.mean((x - u) ** 2, dim=1, keepdim=True)
-        x = (x - u) / tp.sqrt(s + self.eps)
-        w = tp.unsqueeze(tp.unsqueeze(self.weight, 1), 2)
-        b = tp.unsqueeze(tp.unsqueeze(self.bias, 1), 2)
-        x = w * x + b
-        x = tp.cast(x, original_dtype)
-        return x
+        x = tp.permute(x, (0, 2, 3, 1))
+        x = super().forward(x)
+        return tp.permute(x, (0, 3, 1, 2))
 
 
 def get_activation_fn(activation):
```
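Why the permute wrapper is equivalent: the new `forward` relies on the base `LayerNorm` normalizing over the trailing dimension, so permuting NCHW to NHWC, normalizing, and permuting back performs the same per-channel normalization the removed code computed by hand. The check below is not part of the commit; it verifies that equivalence numerically in plain PyTorch so it can be run standalone (the shapes, seed, and tolerance are arbitrary).

```python
# Standalone numerical check (not part of the commit): a last-dim LayerNorm
# applied to an NHWC-permuted tensor matches the removed hand-rolled
# per-channel normalization of an NCHW tensor.
import torch

torch.manual_seed(0)
N, C, H, W = 2, 8, 16, 16
eps = 1e-6
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# Old path: explicit mean/variance over the channel dimension.
u = x.mean(dim=1, keepdim=True)
s = ((x - u) ** 2).mean(dim=1, keepdim=True)
old = weight[:, None, None] * (x - u) / torch.sqrt(s + eps) + bias[:, None, None]

# New path: permute to NHWC, apply a standard last-dim LayerNorm, permute back.
ln = torch.nn.LayerNorm(C, eps=eps)
with torch.no_grad():
    ln.weight.copy_(weight)
    ln.bias.copy_(bias)
new = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(old, new, atol=1e-5))  # expected: True
```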
