
Commit 10706f2

new backbone: ViT pretrained on ImageNet with MAE loss (#302)

* initial implementation
* load weights from pretrained backbone
* fix test
1 parent ecbdac2 commit 10706f2

13 files changed (+339, -46 lines)

.flake8

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [flake8]
-max-line-length = 88
+max-line-length = 99
 # E203: black conflict
 # E701: black conflict
 # F821: lot of issues regarding type annotations
docs/api/lightning_pose.models.backbones.vit_mae.ViTVisionEncoder.rst

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
ViTVisionEncoder
================

.. currentmodule:: lightning_pose.models.backbones.vit_mae

.. autoclass:: ViTVisionEncoder
   :show-inheritance:

   .. rubric:: Methods Summary

   .. autosummary::

      ~ViTVisionEncoder.forward

   .. rubric:: Methods Documentation

   .. automethod:: forward

docs/api/lightning_pose.models.backbones.vit_sam.load_sam_vision_encoder_hf.rst

Lines changed: 0 additions & 6 deletions
This file was deleted.

docs/modules/lightning_pose.models.backbones.rst

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@ lightning\_pose.models.backbones
 .. automodapi:: lightning_pose.models.backbones.torchvision
    :no-inheritance-diagram:

+.. automodapi:: lightning_pose.models.backbones.vit_mae
+   :no-inheritance-diagram:
+
 .. automodapi:: lightning_pose.models.backbones.vit_sam
    :no-inheritance-diagram:

lightning_pose/models/backbones/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,4 +17,5 @@
     "efficientnet_b1",
     "efficientnet_b2",
     "vitb_sam",
+    "vitb_imagenet",
 ]
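
The new entry registers the ViT-MAE backbone under the name "vitb_imagenet", next to the existing "vitb_sam". Below is a minimal, hypothetical sketch of what this enables; the variable name holding this list and the config plumbing that consumes these strings are not visible in this hunk, so the selection logic shown is an assumption, not code from this commit.

# Hypothetical sketch (assumption: backbone strings from the list above are
# mapped to encoder classes elsewhere in lightning-pose; that mapping is not
# part of this hunk).
from lightning_pose.models.backbones.vit_mae import ViTVisionEncoder

backbone_name = "vitb_imagenet"  # new option introduced by this commit
if backbone_name == "vitb_imagenet":
    # build the new ViT-MAE wrapper added in vit_mae.py below
    backbone = ViTVisionEncoder(model_name="facebook/vit-mae-base", finetune_img_size=256)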
lightning_pose/models/backbones/vit_mae.py

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
import math

import torch
import torch.nn as nn
from transformers import ViTMAEConfig, ViTMAEForPreTraining

# to ignore imports for sphinx-autoapidoc
__all__ = [
    "ViTVisionEncoder",
]


class ViTVisionEncoder(nn.Module):
    """Wrapper around HuggingFace's ViTMAE Vision Encoder."""

    def __init__(
        self,
        model_name: str = "facebook/vit-mae-base",
        finetune_img_size: int = 256,
    ):
        super().__init__()

        if model_name == "facebook/vit-mae-base":
            img_size = 224
            config = {
                'hidden_size': 768,
                'num_hidden_layers': 12,
                'num_attention_heads': 12,
                'intermediate_size': 3072,
                'hidden_act': "gelu",
                'hidden_dropout_prob': 0.0,
                'attention_probs_dropout_prob': 0.0,
                'initializer_range': 0.02,
                'layer_norm_eps': 1.e-12,
                'image_size': img_size,  # usually 224
                'patch_size': 16,  # default is 16, we use large patch size
                'num_channels': 3,  # 3 for RGB
                'qkv_bias': True,
                'decoder_num_attention_heads': 16,
                'decoder_hidden_size': 512,
                'decoder_num_hidden_layers': 8,
                'decoder_intermediate_size': 2048,
                'mask_ratio': 0,  # 0 for no masking, usually 0.75 (MAE)
                'norm_pix_loss': False,
            }
        else:
            raise NotImplementedError(f"{model_name} is not a valid ViTVisionEncoder model name")

        # Load the full ViT model and extract encoder
        self.config = ViTMAEConfig(**config)
        self.vision_encoder = ViTMAE.from_pretrained(model_name)
        del self.vision_encoder.decoder  # remove the decoder from the vit_mae
        self.vision_encoder.config.mask_ratio = 0

        # Store size information
        self.img_size = img_size
        self.finetune_img_size = finetune_img_size
        self.patch_size = self.vision_encoder.config.patch_size

        # Store original positional embeddings for potential resizing
        self.original_pos_embed = None
        if hasattr(self.vision_encoder.vit.embeddings, 'position_embeddings'):
            self.original_pos_embed = \
                self.vision_encoder.vit.embeddings.position_embeddings.clone()

        # Check if we need to resize positional embeddings
        if (
            self.finetune_img_size != img_size
            and hasattr(self.vision_encoder.vit.embeddings, 'position_embeddings')
            and self.vision_encoder.vit.embeddings.position_embeddings is not None
        ):
            # Resize positional embeddings if needed
            print(
                f"Finetune image size ({finetune_img_size}) does not match model size ({img_size})"
                f" - recomputing position embeddings"
            )
            self._resize_pos_embed()

        # Bypass size check entirely
        self._bypass_size_check()

    def _bypass_size_check(self):
        """Completely bypass the size check in patch embedding"""

        def no_size_check_forward(pixel_values, interpolate_pos_encoding: bool = False):
            batch_size, num_channels, height, width = pixel_values.shape

            # Only check channel dimension
            if num_channels != self.vision_encoder.vit.config.num_channels:
                raise ValueError(
                    "Make sure that the channel dimension of the pixel values match with the one "
                    "set in the configuration."
                )

            # Skip size check entirely - just do the convolution
            embeddings = self.vision_encoder.vit.embeddings.patch_embeddings.projection(
                pixel_values
            ).flatten(2).transpose(1, 2)
            return embeddings

        # Replace the forward method
        self.vision_encoder.vit.embeddings.patch_embeddings.forward = no_size_check_forward
        print("Bypassed all size checking in embeddings")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        N = x.shape[0]
        if self.config.num_channels == 1:
            # adjust input channels to 1
            x = x[:, 0, ...].unsqueeze(1)
        outputs = self.vision_encoder(
            pixel_values=x,
            return_latent=True,
        )
        # skip the cls token
        outputs = outputs[:, 1:, ...]  # [N, S, D]
        # change the shape to [N, H, W, D] -> [N, D, H, W]
        S = outputs.shape[1]
        H, W = math.isqrt(S), math.isqrt(S)
        outputs = outputs.reshape(N, H, W, -1).permute(0, 3, 1, 2)
        return outputs

    def _resize_pos_embed(self):
        """Resize positional embeddings for different input sizes"""

        if self.original_pos_embed is None:
            return

        # Calculate target size
        old_size = self.img_size // self.patch_size  # 224 // 16 = 14
        new_size = self.finetune_img_size // self.patch_size  # 128 // 16 = 8

        if old_size == new_size:
            return

        print(f"Resizing pos_embed from {old_size}x{old_size} to {new_size}x{new_size}")

        # Original pos_embed format: [1, H*W + 1, C]
        pos_embed = self.original_pos_embed  # [1, 197, 768] for 224x224 input

        # Separate CLS token embedding from spatial embeddings
        cls_token_embed = pos_embed[:, 0:1, :]  # [1, 1, 768] - CLS token
        spatial_embeddings = pos_embed[:, 1:, :]  # [1, 196, 768] - spatial patches

        # Reshape spatial embeddings to 2D spatial format
        # [1, H*W, C] -> [1, H, W, C]
        batch_size, num_patches, embed_dim = spatial_embeddings.shape
        spatial_2d = spatial_embeddings.reshape(batch_size, old_size, old_size, embed_dim)

        # Convert to [1, C, H, W] for interpolation
        spatial_2d = spatial_2d.permute(0, 3, 1, 2)  # [1, 768, 14, 14]

        # Resize using interpolation
        spatial_resized = nn.functional.interpolate(
            spatial_2d,
            size=(new_size, new_size),  # (8, 8)
            mode='bicubic',
            antialias=True,
        )  # [1, 768, 8, 8]

        # Convert back to sequence format
        # [1, C, H, W] -> [1, H, W, C] -> [1, H*W, C]
        spatial_resized = spatial_resized.permute(0, 2, 3, 1)  # [1, 8, 8, 768]
        spatial_resized = spatial_resized.reshape(batch_size, new_size * new_size, embed_dim)

        # Concatenate CLS token back at the beginning
        pos_embed_final = torch.cat([cls_token_embed, spatial_resized], dim=1)  # [1, 65, 768]

        # Update the position embeddings
        self.vision_encoder.vit.embeddings.position_embeddings = nn.Parameter(pos_embed_final)


class ViTMAE(ViTMAEForPreTraining):

    def forward(
        self,
        pixel_values,
        noise=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        interpolate_pos_encoding=False,
        return_latent=False,
        return_recon=False,
    ):
        # Setting default for return_dict based on the configuration
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if (self.training and self.config.mask_ratio > 0) or return_recon:
            outputs = self.vit(
                pixel_values,
                noise=noise,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            latent = outputs.last_hidden_state
        else:
            # use for fine-tuning, or inference
            # mask_ratio = 0
            embedding_output, mask, ids_restore = self.vit.embeddings(pixel_values)
            embedding_output_ = embedding_output[:, 1:, :]  # no cls token
            # unshuffle the embedding output
            embedding_output_ = torch.gather(
                embedding_output_,
                dim=1,
                index=ids_restore.unsqueeze(-1).repeat(
                    1, 1, embedding_output_.shape[2]
                ).to(embedding_output_.device))
            # add cls token back
            embedding_output = torch.cat((embedding_output[:, :1, :], embedding_output_), dim=1)
            encoder_outputs = self.vit.encoder(
                embedding_output,
                return_dict=return_dict,
            )
            sequence_output = encoder_outputs[0]
            latent = self.vit.layernorm(sequence_output)
            if not return_latent:
                # return the cls token and 0 loss if not return_latent
                return latent[:, 0], 0

        if return_latent:
            return latent

        # extract cls latent
        cls_latent = latent[:, 0]  # shape (batch_size, hidden_size)
        ids_restore = outputs.ids_restore
        mask = outputs.mask

        decoder_outputs = self.decoder(latent, ids_restore)
        # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
        logits = decoder_outputs.logits
        # print(decoder_outputs.keys())
        loss = self.forward_loss(pixel_values, logits, mask)
        if return_recon:
            return cls_latent, loss, logits

        return cls_latent, loss
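
As a quick sanity check on the new encoder, here is a minimal usage sketch based only on the code above. It assumes the facebook/vit-mae-base weights can be downloaded from HuggingFace; the expected output shape follows from hidden_size=768, patch_size=16, and finetune_img_size=256.

import torch

from lightning_pose.models.backbones.vit_mae import ViTVisionEncoder

# Build the encoder at the fine-tuning resolution; position embeddings are
# resized from the 224x224 pretraining grid (14x14 patches) to a 16x16 grid.
encoder = ViTVisionEncoder(model_name="facebook/vit-mae-base", finetune_img_size=256)
encoder.eval()

images = torch.randn(2, 3, 256, 256)  # [N, C, H, W]
with torch.no_grad():
    features = encoder(images)

# forward() drops the CLS token and reshapes the 256 patch tokens into a
# 16x16 feature map with 768 channels.
print(features.shape)  # expected: torch.Size([2, 768, 16, 16])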

lightning_pose/models/backbones/vit_sam.py

Lines changed: 10 additions & 28 deletions
@@ -5,6 +5,11 @@
 import torch.nn.functional as F
 from transformers import SamModel
 
+# to ignore imports for sphinx-autoapidoc
+__all__ = [
+    "SamVisionEncoderHF",
+]
+
 
 class SamVisionEncoderHF(nn.Module):
     """Wrapper around HuggingFace's SAM Vision Encoder."""
@@ -48,6 +53,11 @@ def __init__(
         # Bypass size check entirely
         self._bypass_size_check()
 
+        # Disable relative positional encoding in SAM
+        for layer in self.vision_encoder.layers:
+            if hasattr(layer.attn, "use_rel_pos"):
+                layer.attn.use_rel_pos = False
+
     def _bypass_size_check(self):
         """Completely bypass the size check in patch embedding"""
 
@@ -147,31 +157,3 @@ def _resize_pos_embed(self):
 
         # Update the vision encoder's positional embeddings
         self.vision_encoder.pos_embed = nn.Parameter(pos_embed_final)
-
-
-def load_sam_vision_encoder_hf(
-    model_name: str = "facebook/sam-vit-base",
-    finetune_image_size: int = 1024,
-    image_size: int = 1024
-):
-    """Load SAM vision encoder from HuggingFace.
-
-    Args:
-        model_name: HuggingFace model name
-            (facebook/sam-vit-base, facebook/sam-vit-large, facebook/sam-vit-huge)
-        finetune_image_size: Target image size for fine-tuning
-        image_size: Original image size (usually 1024 for SAM)
-
-    Returns:
-        SamVisionEncoderHF instance
-
-    """
-
-    # Create the wrapper
-    encoder = SamVisionEncoderHF(
-        model_name=model_name,
-        finetune_img_size=finetune_image_size,
-        img_size=image_size
-    )
-
-    return encoder
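
With load_sam_vision_encoder_hf removed, callers presumably construct the wrapper directly. A hedged sketch of the equivalent call, using only the constructor keywords visible in the deleted helper:

from lightning_pose.models.backbones.vit_sam import SamVisionEncoderHF

# Direct construction mirroring the body of the deleted helper; the keyword
# names are taken from the removed code above.
encoder = SamVisionEncoderHF(
    model_name="facebook/sam-vit-base",
    finetune_img_size=1024,
    img_size=1024,
)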
