Commit 92869a3

Fixes a dtype mismatch in LayerNorm2D
Parent: b634fd8 · Commit: 92869a3

File tree

1 file changed: 4 additions, 0 deletions

  • tripy/examples/segment-anything-model-v2/sam2/modeling


tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py

Lines changed: 4 additions & 0 deletions
@@ -186,7 +186,11 @@ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:

     def forward(self, x: tp.Tensor) -> tp.Tensor:
         x = tp.permute(x, (0, 2, 3, 1))
+        # LayerNorm is always done in float32:
+        original_dtype = x.dtype
+        x = tp.cast(x, tp.float32)
         x = super().forward(x)
+        x = tp.cast(x, original_dtype)
         return tp.permute(x, (0, 3, 1, 2))
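The change above follows a common mixed-precision pattern: record the input dtype, up-cast to float32 so the mean/variance reduction inside LayerNorm stays numerically stable, then cast the result back so downstream ops see the original dtype. A minimal NumPy sketch of the same pattern (NumPy is used here only for illustration, since the commit itself targets tripy; `layernorm_nhwc_f32` is a hypothetical helper, not part of the repo):

```python
import numpy as np

def layernorm_nhwc_f32(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize over the last (channel) axis, computing in float32.

    Mirrors the commit's pattern: save dtype, cast up, normalize, cast back.
    """
    original_dtype = x.dtype
    x32 = x.astype(np.float32)                 # LayerNorm is always done in float32
    mean = x32.mean(axis=-1, keepdims=True)
    var = x32.var(axis=-1, keepdims=True)
    y = (x32 - mean) / np.sqrt(var + eps)
    return y.astype(original_dtype)            # restore the caller's dtype

# Half-precision input round-trips through float32 math:
x = np.random.randn(1, 4, 4, 8).astype(np.float16)
y = layernorm_nhwc_f32(x)
assert y.dtype == np.float16
```

Without the up-cast, a float16 input would force the variance reduction to run in float16 as well, where the limited mantissa can make the normalization inaccurate; the round-trip keeps the fix local to this layer.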
