import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_model import ModelConfig

class ModelLM(PreTrainedModel):
    config_class = ModelConfig
    base_model_prefix = "backbone"
    # Use the same weight tensor for the input embeddings and the output embeddings.
    _tied_weights_keys = ["lm_head.weight", "backbone.embed.weight"]

    def __init__(self, config: ModelConfig):
        super().__init__(config)

        # Minimal "backbone": a token embedding followed by a small MLP.
        self.backbone = nn.Module()
        self.backbone.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.backbone.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Tanh(),
        )

        # Projection back to the vocabulary; its weight gets tied to the embedding.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initializes weights and calls tie_weights().
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.embed

    def set_input_embeddings(self, value):
        self.backbone.embed = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_emb):
        self.lm_head = new_emb

    def tie_weights(self):
        out_emb = self.get_output_embeddings()  # lm_head (nn.Linear)
        in_emb = self.get_input_embeddings()    # nn.Embedding

        # If either side is missing, do nothing
        if out_emb is None or in_emb is None:
            return

        out_w = out_emb.weight
        in_w = in_emb.weight

        if in_w.device.type == "meta" and out_w.device.type != "meta":
            # IMPORTANT: rebind the Parameter, not just copy data
            in_emb.weight = out_w
            return

        if out_w.device.type == "meta" and in_w.device.type != "meta":
            out_emb.weight = in_w
            return

        # Default HF behavior (ties by reference or clones as needed)
        self._tie_or_clone_weights(out_emb, in_emb)

    def forward(self, input_ids=None, labels=None, **kwargs):
        # input_ids: (batch, seq_len)
        x = self.backbone.embed(input_ids)  # (B, T, H)
        x = self.backbone.mlp(x)            # (B, T, H)
        logits = self.lm_head(x)            # (B, T, V)

        loss = None
        if labels is not None:
            # Classic language-model loss with next-token prediction.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        return CausalLMOutput(loss=loss, logits=logits)
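A minimal usage sketch, separate from the file above: it assumes ModelConfig accepts vocab_size and hidden_size keyword arguments (the two fields read in __init__) and that both classes are importable from the package; adjust the constructor call to the real configuration class. The assertion only checks that post_init() has tied the two weight tensors.

import torch

# Hypothetical config values, just to exercise the class defined above.
config = ModelConfig(vocab_size=100, hidden_size=32)
model = ModelLM(config)

# post_init() already called tie_weights(), so both modules share one Parameter.
assert model.lm_head.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # (batch=2, seq_len=8)
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # scalar loss, logits of shape (2, 8, 100)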