-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
35 lines (30 loc) · 1.32 KB
/
Copy pathmodel.py
File metadata and controls
35 lines (30 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os

import torch
import torch.nn as nn

from attention import Embedder, GPTBlock
class MiniGPT(nn.Module):
    """Mini GPT-style decoder-only transformer for next-token prediction.

    The model chains an ``Embedder``, a stack of ``GPTBlock`` layers, a final
    ``LayerNorm`` and a bias-free linear head whose weight is shared with the
    embedder's token embedding.

    Args:
        d_vocab: Vocabulary size; passed to the embedder and used as the
            output width of the head.
        n_layers: Number of ``GPTBlock`` layers to stack.
        d_model: Model width; passed to the embedder, every block, the final
            norm and the head.
        n_heads: Forwarded to every ``GPTBlock``.
        d_ffn: Forwarded to every ``GPTBlock``.
        max_len: Forwarded to the embedder and every block; also kept on
            ``self.max_len``.
        dropout: Forwarded to every ``GPTBlock``.
    """

    def __init__(self, d_vocab, n_layers, d_model, n_heads, d_ffn, max_len, dropout=0.1):
        super().__init__()
        self.embedder = Embedder(d_vocab, d_model, max_len)
        self.blocks = nn.ModuleList([
            GPTBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ffn=d_ffn,
                dropout=dropout,
                max_len=max_len,
            )
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, d_vocab, bias=False)
        # Tie output weights to the input embedding: both modules share one
        # Parameter object.
        # NOTE(review): relies on Embedder exposing ``token_emb`` with a
        # (d_vocab, d_model) weight -- confirm against attention.py.
        self.head.weight = self.embedder.token_emb.weight
        self.max_len = max_len

    def forward(self, x, key_padding_mask=None):
        """Run the full model and return the head's output.

        Args:
            x: Input handed to the embedder (presumably integer token ids of
                shape (batch, seq) -- confirm against ``Embedder``).
            key_padding_mask: Optional mask forwarded unchanged to every
                block; its semantics are defined by ``GPTBlock``.

        Returns:
            The output of the tied linear head, i.e. one score per vocabulary
            entry along the last dimension.
        """
        x = self.embedder(x)
        for block in self.blocks:
            x = block(x, key_padding_mask=key_padding_mask)
        x = self.norm(x)
        x = self.head(x)
        return x

    def save_weights(self, path="./checkpoints/gpt_weights.pt"):
        """Write the model's ``state_dict`` to ``path`` with ``torch.save``.

        When ``path`` is a filesystem path, missing parent directories are
        created first so that the default ``./checkpoints/`` location works
        without manual setup. Other objects (e.g. file-like buffers) are
        passed to ``torch.save`` untouched.
        """
        if isinstance(path, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(path))
            if parent:  # empty for a bare filename in the current directory
                os.makedirs(parent, exist_ok=True)
        torch.save(self.state_dict(), path)

    def load_weights(self, path="./checkpoints/test", strict=True):
        """Load a ``state_dict`` from ``path`` into this model.

        Args:
            path: Checkpoint location handed to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.
        """
        # NOTE(review): this default differs from save_weights' default
        # ("./checkpoints/gpt_weights.pt") -- confirm that is intended.
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources (or pass weights_only=True where the installed
        # torch version supports it).
        # Tensors are deserialized onto the CPU before being copied into the
        # model's existing parameters.
        state = torch.load(path, map_location='cpu')
        self.load_state_dict(state, strict=strict)