Skip to content

Commit 6c3162a

Browse files
Removes unnecessary casting around normalization layers
1 parent 92869a3 commit 6c3162a

File tree

7 files changed

+35
-49
lines changed

7 files changed

+35
-49
lines changed

tripy/examples/segment-anything-model-v2/sam2/build_sam.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def get_component_configs(model, cfg):
8484
# TODO (#594): Remove this hack once we are able to pass in DimensionSizes directly:
8585
tp.InputInfo(((4, 16, 64),), tp.int32),
8686
],
87-
"skip_dtype_convert": ["ln", "norm"],
87+
"skip_dtype_convert": [],
8888
},
8989
"sam_mask_decoder_false": {
9090
"enabled": True,
@@ -118,7 +118,7 @@ def get_component_configs(model, cfg):
118118
dtype=getattr(tp, model_precision),
119119
), # high_res_features_2
120120
],
121-
"skip_dtype_convert": ["ln", "norm", "output_upscaling.1"],
121+
"skip_dtype_convert": [],
122122
},
123123
"sam_mask_decoder_true": {
124124
"enabled": True,
@@ -152,7 +152,7 @@ def get_component_configs(model, cfg):
152152
dtype=getattr(tp, model_precision),
153153
), # high_res_features_2
154154
],
155-
"skip_dtype_convert": ["ln", "norm", "output_upscaling.1"],
155+
"skip_dtype_convert": [],
156156
"skip_load_state_dict": True,
157157
},
158158
"sam_mask_decoder.conv_s0": {
@@ -190,8 +190,7 @@ def get_component_configs(model, cfg):
190190
tp.InputInfo((batch, num_obj, 1024, 1024), getattr(tp, model_precision)),
191191
True,
192192
],
193-
"skip_dtype_convert": ["ln", "norm"]
194-
+ [f"encoder.{i}.{param}" for i in range(1, 40, 3) for param in ("weight", "bias")],
193+
"skip_dtype_convert": [],
195194
},
196195
"sam_prompt_encoder": {
197196
"enabled": True,
@@ -230,7 +229,7 @@ def get_component_configs(model, cfg):
230229
dtype=getattr(tp, model_precision),
231230
),
232231
],
233-
"skip_dtype_convert": ["norm"],
232+
"skip_dtype_convert": [],
234233
"special_key_loading": lambda key: (
235234
# If it's a neck.convs key that contains 'conv.'
236235
# neck.convs.0.conv.weight -> neck.convs.0.weight

tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/hieradet.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
super().__init__()
114114

115115
if isinstance(norm_layer, str):
116-
norm_layer = partial(getattr(tp, norm_layer), eps=1e-6)
116+
norm_layer = partial(getattr(tp, norm_layer), eps=1e-6, dtype=dtype)
117117

118118
self.dim = dim
119119
self.dim_out = dim_out
@@ -149,15 +149,8 @@ def __init__(
149149
self.proj = tp.Linear(dim, dim_out, dtype=dtype)
150150

151151
def forward(self, x):
152-
153-
def call_norm(x, norm):
154-
x_dtype = x.dtype
155-
x = tp.cast(x, tp.float32)
156-
x = norm(x)
157-
return tp.cast(x, x_dtype)
158-
159152
shortcut = x # B, H, W, C
160-
x = call_norm(x, self.norm1)
153+
x = self.norm1(x)
161154

162155
# Skip connection
163156
if self.dim != self.dim_out:
@@ -189,7 +182,7 @@ def mod_int(x, y):
189182

190183
x = shortcut + x
191184
# MLP
192-
t = call_norm(x, self.norm2)
185+
t = self.norm2(x)
193186
x = x + self.mlp(t)
194187
return x
195188

tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def __init__(
5454
self.linear1 = tp.Linear(d_model, dim_feedforward, dtype=self.dtype)
5555
self.linear2 = tp.Linear(dim_feedforward, d_model, dtype=self.dtype)
5656

57-
self.norm1 = tp.LayerNorm(d_model)
58-
self.norm2 = tp.LayerNorm(d_model)
59-
self.norm3 = tp.LayerNorm(d_model)
57+
self.norm1 = tp.LayerNorm(d_model, dtype=self.dtype)
58+
self.norm2 = tp.LayerNorm(d_model, dtype=self.dtype)
59+
self.norm3 = tp.LayerNorm(d_model, dtype=self.dtype)
6060

6161
self.activation_str = activation
6262
self.activation = get_activation_fn(activation)
@@ -68,15 +68,15 @@ def __init__(
6868

6969
def _forward_sa(self, tgt, query_pos):
7070
# Self-Attention
71-
tgt2 = tp.cast(self.norm1(tp.cast(tgt, self.norm1.dtype)), self.dtype)
71+
tgt2 = self.norm1(tgt)
7272
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
7373
tgt2 = self.self_attn(q, k, v=tgt2, num_k_exclude_rope=0)
7474
tgt = tgt + tgt2
7575
return tgt
7676

7777
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
7878
# Cross-Attention
79-
tgt2 = tp.cast(self.norm2(tp.cast(tgt, self.norm2.dtype)), self.dtype)
79+
tgt2 = self.norm2(tgt)
8080

8181
tgt2 = self.cross_attn_image(
8282
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
@@ -100,7 +100,7 @@ def forward(
100100
tgt = self._forward_sa(tgt, query_pos)
101101
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
102102
# MLP
103-
tgt2 = tp.cast(self.norm3(tp.cast(tgt, self.norm3.dtype)), self.dtype)
103+
tgt2 = self.norm3(tgt)
104104

105105
tgt2 = self.linear2(self.activation(self.linear1(tgt2)))
106106
tgt = tgt + tgt2
@@ -137,12 +137,13 @@ def __init__(
137137
dtype="float32",
138138
):
139139
super().__init__()
140+
self.dtype = getattr(tp, dtype)
141+
140142
self.d_model = d_model
141143
self.num_layers = num_layers
142-
self.norm = tp.LayerNorm(d_model)
144+
self.norm = tp.LayerNorm(d_model, self.dtype)
143145
self.pos_enc_at_input = pos_enc_at_input
144146
self.batch_first = batch_first
145-
self.dtype = getattr(tp, dtype)
146147
self.layers = []
147148
for _ in range(num_layers):
148149
self_attn = RoPEAttention(
@@ -215,7 +216,7 @@ def forward(
215216
**kwds,
216217
)
217218

218-
normed_output = tp.cast(self.norm(tp.cast(output, self.norm.dtype)), self.dtype)
219+
normed_output = self.norm(output)
219220

220221
if self.batch_first:
221222
# Convert back to seq first

tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
dtype=self.dtype,
7272
)
7373
)
74-
self.encoder.append(LayerNorm2d(mask_out_chans))
74+
self.encoder.append(LayerNorm2d(mask_out_chans, dtype=self.dtype))
7575
self.encoder.append(activation)
7676
mask_in_chans = mask_out_chans
7777

@@ -108,7 +108,7 @@ def __init__(
108108
groups=dim if use_dwconv else 1,
109109
dtype=self.dtype,
110110
) # depthwise conv
111-
self.norm = LayerNorm2d(dim, eps=1e-6)
111+
self.norm = LayerNorm2d(dim, eps=1e-6, dtype=self.dtype)
112112
self.pwconv1 = tp.Linear(dim, 4 * dim, dtype=self.dtype) # pointwise/1x1 convs, implemented with linear layers
113113
self.act = tp.gelu
114114
self.pwconv2 = tp.Linear(4 * dim, dim, dtype=self.dtype)

tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
stride=(2, 2),
100100
dtype=dtype,
101101
),
102-
LayerNorm2d(transformer_dim // 4),
102+
LayerNorm2d(transformer_dim // 4, dtype=dtype),
103103
Dummy(), # Accounts for Dropout layer, needed for weight loading
104104
tp.ConvTranspose(
105105
transformer_dim // 4,
@@ -289,13 +289,13 @@ def predict_masks(
289289

290290
if not self.use_high_res_features:
291291
dc1, ln1, _, dc2, _ = self.output_upscaling
292-
post_ln1 = tp.cast(ln1(tp.cast(dc1(src), tp.float32)), src.dtype)
292+
post_ln1 = ln1(dc1(src))
293293
upscaled_embedding = act2(dc2(act1(post_ln1)))
294294
# upscaled_embedding = act2(dc2(act1(ln1(dc1(src)))))
295295
else:
296296
dc1, ln1, _, dc2, _ = self.output_upscaling
297297
feat_s0, feat_s1 = high_res_features_1, high_res_features_2
298-
post_ln1 = tp.cast(ln1(tp.cast(dc1(src) + feat_s1, tp.float32)), src.dtype)
298+
post_ln1 = ln1(dc1(src) + feat_s1)
299299
upscaled_embedding = act1(post_ln1)
300300
# upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
301301
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
downsample_rate=attention_downsample_rate,
8282
dtype=dtype,
8383
)
84-
self.norm_final_attn = tp.LayerNorm(embedding_dim)
84+
self.norm_final_attn = tp.LayerNorm(embedding_dim, dtype=dtype)
8585

8686
def forward(
8787
self,
@@ -134,10 +134,7 @@ def forward_impl(
134134
k = keys + image_pe
135135
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
136136
queries = queries + attn_out
137-
queries = tp.cast(
138-
self.norm_final_attn(tp.cast(queries, self.norm_final_attn.dtype)),
139-
queries.dtype,
140-
)
137+
queries = self.norm_final_attn(queries)
141138
# queries = self.norm_final_attn(queries)
142139

143140
return queries, keys
@@ -170,15 +167,15 @@ def __init__(
170167
"""
171168
super().__init__()
172169
self.self_attn = Attention(embedding_dim, num_heads, dtype=dtype)
173-
self.norm1 = tp.LayerNorm(embedding_dim)
170+
self.norm1 = tp.LayerNorm(embedding_dim, dtype=dtype)
174171

175172
self.cross_attn_token_to_image = Attention(
176173
embedding_dim,
177174
num_heads,
178175
downsample_rate=attention_downsample_rate,
179176
dtype=dtype,
180177
)
181-
self.norm2 = tp.LayerNorm(embedding_dim)
178+
self.norm2 = tp.LayerNorm(embedding_dim, dtype=dtype)
182179

183180
self.mlp = MLP(
184181
embedding_dim,
@@ -188,9 +185,9 @@ def __init__(
188185
activation=activation,
189186
dtype=dtype,
190187
)
191-
self.norm3 = tp.LayerNorm(embedding_dim)
188+
self.norm3 = tp.LayerNorm(embedding_dim, dtype=dtype)
192189

193-
self.norm4 = tp.LayerNorm(embedding_dim)
190+
self.norm4 = tp.LayerNorm(embedding_dim, dtype=dtype)
194191
self.cross_attn_image_to_token = Attention(
195192
embedding_dim,
196193
num_heads,
@@ -212,29 +209,29 @@ def forward_impl(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe:
212209
attn_out = self.self_attn(q=q, k=q, v=queries)
213210
queries = queries + attn_out
214211

215-
queries = tp.cast(self.norm1(tp.cast(queries, self.norm1.dtype)), queries.dtype)
212+
queries = self.norm1(queries)
216213

217214
# Cross attention block, tokens attending to image embedding
218215
q = queries + query_pe
219216
k = keys + key_pe
220217
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
221218
queries = queries + attn_out
222219

223-
queries = tp.cast(self.norm2(tp.cast(queries, self.norm2.dtype)), queries.dtype)
220+
queries = self.norm2(queries)
224221
# queries = self.norm2(queries)
225222

226223
# MLP block
227224
mlp_out = self.mlp(queries)
228225
queries = queries + mlp_out
229-
queries = tp.cast(self.norm3(tp.cast(queries, self.norm3.dtype)), queries.dtype)
226+
queries = self.norm3(queries)
230227
# queries = self.norm3(queries)
231228

232229
# Cross attention block, image embedding attending to tokens
233230
q = queries + query_pe
234231
k = keys + key_pe
235232
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
236233
keys = keys + attn_out
237-
keys = tp.cast(self.norm4(tp.cast(keys, self.norm4.dtype)), keys.dtype)
234+
keys = self.norm4(keys)
238235
# keys = self.norm4(keys)
239236

240237
return queries, keys

tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,12 @@ def forward(self, x):
181181

182182

183183
class LayerNorm2d(tp.LayerNorm):
184-
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
185-
super().__init__(num_channels, dtype=tp.float32, eps=eps)
184+
def __init__(self, num_channels: int, eps: float = 1e-6, dtype: tp.dtype = tp.float32) -> None:
185+
super().__init__(num_channels, dtype=dtype, eps=eps)
186186

187187
def forward(self, x: tp.Tensor) -> tp.Tensor:
188188
x = tp.permute(x, (0, 2, 3, 1))
189-
# LayerNorm is always done in float32:
190-
original_dtype = x.dtype
191-
x = tp.cast(x, tp.float32)
192189
x = super().forward(x)
193-
x = tp.cast(x, original_dtype)
194190
return tp.permute(x, (0, 3, 1, 2))
195191

196192

0 commit comments

Comments
 (0)