import math

import torch
import torch.nn as nn
from transformers import ViTMAEConfig, ViTMAEForPreTraining

# to ignore imports for sphinx-autoapidoc
__all__ = [
    "ViTVisionEncoder",
]


class ViTVisionEncoder(nn.Module):
    """Wrapper around HuggingFace's ViT-MAE vision encoder."""

    def __init__(
        self,
        model_name: str = "facebook/vit-mae-base",
        finetune_img_size: int = 256,
    ):
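        """Initialize the encoder.

        Args:
            model_name: name of the pretrained ViT-MAE checkpoint; currently only
                "facebook/vit-mae-base" is supported.
            finetune_img_size: image size used during fine-tuning; if it differs from
                the pretraining size (224), positional embeddings are resized.
        """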
        super().__init__()

        if model_name == "facebook/vit-mae-base":
            img_size = 224
            config = {
                'hidden_size': 768,
                'num_hidden_layers': 12,
                'num_attention_heads': 12,
                'intermediate_size': 3072,
                'hidden_act': "gelu",
                'hidden_dropout_prob': 0.0,
                'attention_probs_dropout_prob': 0.0,
                'initializer_range': 0.02,
                'layer_norm_eps': 1e-12,
                'image_size': img_size,  # usually 224
                'patch_size': 16,  # ViT-MAE base default
                'num_channels': 3,  # 3 for RGB
                'qkv_bias': True,
                'decoder_num_attention_heads': 16,
                'decoder_hidden_size': 512,
                'decoder_num_hidden_layers': 8,
                'decoder_intermediate_size': 2048,
                'mask_ratio': 0,  # 0 for no masking, usually 0.75 (MAE)
                'norm_pix_loss': False,
            }
        else:
            raise NotImplementedError(f"{model_name} is not a valid ViTVisionEncoder model name")

        # Load the full ViT-MAE model and keep only the encoder
        self.config = ViTMAEConfig(**config)
        self.vision_encoder = ViTMAE.from_pretrained(model_name)
        del self.vision_encoder.decoder  # remove the decoder from the vit_mae
        self.vision_encoder.config.mask_ratio = 0

        # Store size information
        self.img_size = img_size
        self.finetune_img_size = finetune_img_size
        self.patch_size = self.vision_encoder.config.patch_size

        # Store original positional embeddings for potential resizing
        self.original_pos_embed = None
        if hasattr(self.vision_encoder.vit.embeddings, 'position_embeddings'):
            self.original_pos_embed = \
                self.vision_encoder.vit.embeddings.position_embeddings.clone()

        # Check if we need to resize positional embeddings
        if (
            self.finetune_img_size != img_size
            and hasattr(self.vision_encoder.vit.embeddings, 'position_embeddings')
            and self.vision_encoder.vit.embeddings.position_embeddings is not None
        ):
            # Resize positional embeddings if needed
            print(
                f"Finetune image size ({finetune_img_size}) does not match model size ({img_size})"
                " - recomputing position embeddings"
            )
            self._resize_pos_embed()

        # Bypass the patch-embedding size check entirely
        self._bypass_size_check()

    def _bypass_size_check(self):
        """Completely bypass the size check in the patch embedding."""
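        # Note: the stock ViTMAEPatchEmbeddings.forward raises a ValueError when the
        # input height/width differ from config.image_size. Since the positional
        # embeddings have already been resized for finetune_img_size, only the channel
        # check is kept and the convolutional projection is applied directly.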

        def no_size_check_forward(pixel_values, interpolate_pos_encoding: bool = False):
            batch_size, num_channels, height, width = pixel_values.shape

            # Only check the channel dimension
            if num_channels != self.vision_encoder.vit.config.num_channels:
                raise ValueError(
                    "Make sure that the channel dimension of the pixel values match with the one "
                    "set in the configuration."
                )

            # Skip the size check entirely - just apply the patch projection
            embeddings = self.vision_encoder.vit.embeddings.patch_embeddings.projection(
                pixel_values
            ).flatten(2).transpose(1, 2)
            return embeddings

        # Replace the forward method of the patch embedding module
        self.vision_encoder.vit.embeddings.patch_embeddings.forward = no_size_check_forward
        print("Bypassed all size checking in embeddings")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
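        """Encode a batch of images into a spatial feature map.

        Args:
            x: input images of shape [N, C, H, W].

        Returns:
            Patch features of shape [N, D, H/p, W/p], where D is the hidden size
            and p is the patch size (the CLS token is dropped).
        """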
        N = x.shape[0]
        if self.config.num_channels == 1:
            # adjust input channels to 1
            x = x[:, 0, ...].unsqueeze(1)
        outputs = self.vision_encoder(
            pixel_values=x,
            return_latent=True,
        )
        # skip the cls token
        outputs = outputs[:, 1:, ...]  # [N, S, D]
        # reshape [N, S, D] -> [N, H, W, D] -> [N, D, H, W]
        S = outputs.shape[1]
        H, W = math.isqrt(S), math.isqrt(S)
        outputs = outputs.reshape(N, H, W, -1).permute(0, 3, 1, 2)
        return outputs

    def _resize_pos_embed(self):
        """Resize positional embeddings for a different input size."""
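        # Standard 2D interpolation of the learned position embeddings (as commonly
        # done when fine-tuning ViTs at a new resolution): the CLS embedding is kept
        # as-is and the patch grid is interpolated to the finetune grid size.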

        if self.original_pos_embed is None:
            return

        # Calculate source and target grid sizes
        old_size = self.img_size // self.patch_size  # 224 // 16 = 14
        new_size = self.finetune_img_size // self.patch_size  # e.g. 128 // 16 = 8

        if old_size == new_size:
            return

        print(f"Resizing pos_embed from {old_size}x{old_size} to {new_size}x{new_size}")

        # Original pos_embed format: [1, H*W + 1, C]
        pos_embed = self.original_pos_embed  # [1, 197, 768] for 224x224 input

        # Separate the CLS token embedding from the spatial embeddings
        cls_token_embed = pos_embed[:, 0:1, :]  # [1, 1, 768] - CLS token
        spatial_embeddings = pos_embed[:, 1:, :]  # [1, 196, 768] - spatial patches

        # Reshape spatial embeddings to 2D spatial format
        # [1, H*W, C] -> [1, H, W, C]
        batch_size, num_patches, embed_dim = spatial_embeddings.shape
        spatial_2d = spatial_embeddings.reshape(batch_size, old_size, old_size, embed_dim)

        # Convert to [1, C, H, W] for interpolation
        spatial_2d = spatial_2d.permute(0, 3, 1, 2)  # [1, 768, 14, 14]

        # Resize using bicubic interpolation
        spatial_resized = nn.functional.interpolate(
            spatial_2d,
            size=(new_size, new_size),  # e.g. (8, 8)
            mode='bicubic',
            antialias=True,
        )  # e.g. [1, 768, 8, 8]

        # Convert back to sequence format
        # [1, C, H, W] -> [1, H, W, C] -> [1, H*W, C]
        spatial_resized = spatial_resized.permute(0, 2, 3, 1)  # e.g. [1, 8, 8, 768]
        spatial_resized = spatial_resized.reshape(batch_size, new_size * new_size, embed_dim)

        # Concatenate the CLS token back at the beginning
        pos_embed_final = torch.cat([cls_token_embed, spatial_resized], dim=1)  # e.g. [1, 65, 768]

        # Update the position embeddings
        self.vision_encoder.vit.embeddings.position_embeddings = nn.Parameter(pos_embed_final)


class ViTMAE(ViTMAEForPreTraining):
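    """ViTMAEForPreTraining with a forward pass that can return encoder latents.

    The overridden forward adds `return_latent` (return the full token sequence from
    the encoder) and `return_recon` (additionally return the decoder logits) on top
    of the standard masked-autoencoding path.
    """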

    def forward(
        self,
        pixel_values,
        noise=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        interpolate_pos_encoding=False,
        return_latent=False,
        return_recon=False,
    ):
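        """Run the ViT-MAE forward pass.

        When training with masking (or when `return_recon` is True), the standard
        masked encoder is used; otherwise all patches are kept and only unshuffled.
        Returns the full latent sequence if `return_latent` is True; otherwise returns
        the CLS latent and a loss (the MAE reconstruction loss, or 0 when no masking
        was applied), plus the decoder logits when `return_recon` is True.
        """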
        # Set the default for return_dict based on the configuration
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if (self.training and self.config.mask_ratio > 0) or return_recon:
            outputs = self.vit(
                pixel_values,
                noise=noise,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            latent = outputs.last_hidden_state
        else:
            # used for fine-tuning or inference (mask_ratio = 0)
            embedding_output, mask, ids_restore = self.vit.embeddings(pixel_values)
            embedding_output_ = embedding_output[:, 1:, :]  # no cls token
            # unshuffle the embedding output
            embedding_output_ = torch.gather(
                embedding_output_,
                dim=1,
                index=ids_restore.unsqueeze(-1).repeat(
                    1, 1, embedding_output_.shape[2]
                ).to(embedding_output_.device),
            )
            # add the cls token back
            embedding_output = torch.cat((embedding_output[:, :1, :], embedding_output_), dim=1)
            encoder_outputs = self.vit.encoder(
                embedding_output,
                return_dict=return_dict,
            )
            sequence_output = encoder_outputs[0]
            latent = self.vit.layernorm(sequence_output)
            if not return_latent:
                # return the cls token and zero loss when no reconstruction is computed
                return latent[:, 0], 0

        if return_latent:
            return latent

        # extract the cls latent
        cls_latent = latent[:, 0]  # shape (batch_size, hidden_size)
        ids_restore = outputs.ids_restore
        mask = outputs.mask

        decoder_outputs = self.decoder(latent, ids_restore)
        # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
        logits = decoder_outputs.logits
        loss = self.forward_loss(pixel_values, logits, mask)
        if return_recon:
            return cls_latent, loss, logits

        return cls_latent, loss
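

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): assumes the
    # "facebook/vit-mae-base" checkpoint can be downloaded from the HuggingFace Hub.
    encoder = ViTVisionEncoder(
        model_name="facebook/vit-mae-base",
        finetune_img_size=256,
    )
    encoder.eval()
    with torch.no_grad():
        images = torch.randn(2, 3, 256, 256)  # [N, C, H, W]
        features = encoder(images)
    # 256 / 16 = 16 patches per side, hidden size 768 -> torch.Size([2, 768, 16, 16])
    print(features.shape)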