Skip to content

Commit 345c099

Browse files
committed
Update LayerNorm2D to use TRT normalization API
1 parent 3ae73ea commit 345c099

File tree

1 file changed

+6
-10
lines changed
  • tripy/examples/segment-anything-model-v2/sam2/modeling

1 file changed

+6
-10
lines changed

tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,14 @@ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
188188
self.weight = DefaultParameter((num_channels,), tp.float32)
189189
self.bias = DefaultParameter((num_channels,), tp.float32)
190190
self.eps = eps
191+
self.num_channels = num_channels
191192

192193
def forward(self, x: tp.Tensor) -> tp.Tensor:
193-
original_dtype = x.dtype
194-
x = tp.cast(x, tp.float32)
195-
u = tp.mean(x, dim=1, keepdim=True)
196-
s = tp.mean((x - u) ** 2, dim=1, keepdim=True)
197-
x = (x - u) / tp.sqrt(s + self.eps)
198-
w = tp.unsqueeze(tp.unsqueeze(self.weight, 1), 2)
199-
b = tp.unsqueeze(tp.unsqueeze(self.bias, 1), 2)
200-
x = w * x + b
201-
x = tp.cast(x, original_dtype)
202-
return x
194+
x_perm = tp.permute(x, (0, 2, 3, 1))
195+
ln = tp.LayerNorm(self.num_channels, dtype=x.dtype, eps=self.eps)
196+
ln.weight = self.weight
197+
ln.bias = self.bias
198+
return tp.permute(ln(x_perm), (0, 3, 1, 2))
203199

204200

205201
def get_activation_fn(activation):

0 commit comments

Comments
 (0)