class DistilBertCRFForTokenClassification(nn.Module):
    """
-    DistilBERT ➜ dropout ➜ linear projection ➜ CRF.
-    The CRF layer models label-to-label transitions, so the model
-    is optimised at *sequence* level rather than *token* level.
+    Token-level classifier that combines DistilBERT with a CRF layer for structured prediction.
+
+    Architecture:
+        input_ids, attention_mask
+              ↓
+        DistilBERT (pretrained encoder)
+              ↓
+        Dropout
+              ↓
+        Linear layer (projects hidden size → num_labels)
+              ↓
+        CRF layer (models sequence-level transitions)
+
+    Training:
+        - Uses the negative log-likelihood from the CRF as the loss.
+        - Learns both emission scores (token-level confidence) and
+          transition scores (label-to-label sequence consistency).
+
+    Inference:
+        - Uses Viterbi decoding to predict the most likely sequence of labels.
+
+    Output:
+        During training:
+            {"loss": ..., "logits": ...}
+        During inference:
+            {"logits": ..., "predictions": List[List[int]]}
+
+    Example shapes:
+        input_ids:      [B, T] — e.g. [16, 128]
+        attention_mask: [B, T] — 1 for real tokens, 0 for padding
+        logits:         [B, T, C] — C = number of label classes
    """
-    def __init__(self,
-                 num_labels: int,
-                 id2label: dict,
-                 label2id: dict,
-                 pretrained_name: str = "distilbert-base-uncased",
-                 dropout_prob: float = 0.1):
+    def __init__(self, num_labels: int, id2label: dict, label2id: dict, pretrained_name: str = "distilbert-base-uncased", dropout_prob: float = 0.1):
        super().__init__()

        self.config = DistilBertConfig.from_pretrained(
@@ -29,11 +52,34 @@ def __init__(self,
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

-    def forward(self,
-                input_ids=None,
-                attention_mask=None,
-                labels=None,
-                **kwargs):
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
56+ """
57+ Forward pass for training or inference.
58+
59+ Args:
60+ input_ids (Tensor): Token IDs of shape [B, T]
61+ attention_mask (Tensor): Attention mask of shape [B, T]
62+ labels (Tensor, optional): Ground-truth labels of shape [B, T]. Required during training.
63+ kwargs: Any additional DistilBERT-compatible inputs (e.g., head_mask, position_ids, etc.)
64+
65+ Returns:
66+ If labels are provided (training mode):
67+ dict with:
68+ - loss (Tensor): scalar negative log-likelihood from CRF
69+ - logits (Tensor): emission scores of shape [B, T, C]
70+
71+ If labels are not provided (inference mode):
72+ dict with:
73+ - logits (Tensor): emission scores of shape [B, T, C]
74+ - predictions (List[List[int]]): decoded label IDs from CRF,
75+ one list per sequence,
76+ each of length T-2 (excluding [CLS] and [SEP])
77+
78+ Notes:
79+ - logits: [B, T, C], where B = batch size, T = sequence length, C = number of label classes
80+ - predictions: List[List[int]], where each inner list has length T-2
81+ (i.e., excludes [CLS] and [SEP]) and contains Viterbi-decoded label IDs
82+ """

        # Hugging Face occasionally injects helper fields (e.g. num_items_in_batch).
        # Filter `kwargs` down to what DistilBertModel.forward actually accepts.
@@ -48,36 +94,49 @@ def forward(self,
            attention_mask=attention_mask,
            **bert_kwargs,
        )
-        # —— Build emissions once ——————————————————————————————
-        sequence_output = self.dropout(outputs[0])           # [B, T, H]
-        emission_scores = self.classifier(sequence_output)    # [B, T, C]
+        # 1) Compute per-token emission scores.
+        #    Apply dropout to the DistilBERT hidden states, then project them to label logits.
+        #    Shape: [B, T, C], where B = batch size, T = sequence length, C = number of classes.
+        sequence_output = self.dropout(outputs[0])
+        emission_scores = self.classifier(sequence_output)

-        # ============================== TRAINING ==============================
        if labels is not None:
-            # 1. Drop [CLS] (idx 0) and [SEP] (idx -1)
-            emissions = emission_scores[:, 1:-1, :]   # [B, T-2, C]
-            tags = labels[:, 1:-1].clone()            # [B, T-2]
-            crf_mask = (tags != -100)                 # True = keep
+            # 2) Remove the [CLS] and [SEP] special tokens from emissions and labels.
+            #    These tokens were added by the tokenizer but are not part of the identifier.
+            emissions = emission_scores[:, 1:-1, :]   # [B, T-2, C]
+            tags = labels[:, 1:-1].clone()            # [B, T-2]

-            # 2. For any position that's masked-off ➜ set tag to a valid id (0)
+            # 3) Create a mask: True where the label is valid, False where the label == -100.
+            #    The CRF will use this to ignore special/padded tokens.
+            crf_mask = (tags != -100)
+
+            # 4) Replace invalid label positions (-100) with a dummy label (e.g., 0).
+            #    This is required because the CRF expects a label at every position, even if masked.
            tags[~crf_mask] = 0

-            # 3. Guarantee first timestep is ON for every sequence
+            # 5) Ensure the first token of every sequence is active in the CRF mask.
+            #    This avoids CRF errors when the first token is masked out (which breaks decoding).
            first_off = (~crf_mask[:, 0]).nonzero(as_tuple=True)[0]
            if len(first_off):
-                crf_mask[first_off, 0] = True   # flip mask to ON
-                tags[first_off, 0] = 0          # give it tag 0
+                crf_mask[first_off, 0] = True
+                tags[first_off, 0] = 0   # assign a dummy label

+            # 6) Compute the CRF negative log-likelihood loss.
            loss = -self.crf(emissions, tags, mask=crf_mask, reduction="mean")
            return {"loss": loss, "logits": emission_scores}

-        # ============================= INFERENCE ==============================
        else:
-            crf_mask = attention_mask[:, 1:-1].bool()   # [B, T-2]
-            emissions = emission_scores[:, 1:-1, :]     # [B, T-2, C]
+            # INFERENCE MODE
+
+            # 2) Remove [CLS] and [SEP] from emissions and build the CRF mask from the attention mask.
+            #    Only use the inner content of the input sequence.
+            crf_mask = attention_mask[:, 1:-1].bool()   # [B, T-2]
+            emissions = emission_scores[:, 1:-1, :]     # [B, T-2, C]
+
+            # 3) Run Viterbi decoding to get the best label sequence for each input.
            best_paths = self.crf.decode(emissions, mask=crf_mask)
-            return {"logits": emission_scores,
-                    "predictions": best_paths}
+            return {"logits": emission_scores, "predictions": best_paths}
+
    @classmethod
    def from_pretrained(cls, ckpt_dir, local=False, **kw):
        from safetensors.torch import load_file as load_safe_file
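
A minimal usage sketch of how the model above might be driven end to end. This is an assumption-laden illustration rather than code from the commit: the label set, the example string, and the AutoTokenizer call are invented for the sketch, and the constructor is assumed to build the DistilBERT encoder from pretrained_name as the elided __init__ lines suggest.

import torch
from transformers import AutoTokenizer

# Hypothetical label scheme, purely for illustration.
id2label = {0: "O", 1: "B-ID", 2: "I-ID"}
label2id = {name: idx for idx, name in id2label.items()}

model = DistilBertCRFForTokenClassification(
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
enc = tokenizer(["getUserName"], return_tensors="pt")   # adds [CLS] and [SEP]

# Inference: no labels -> {"logits": ..., "predictions": ...} with Viterbi-decoded ids.
model.eval()
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out["predictions"])   # one list of label ids per sequence, special tokens excluded

# Training: labels are [B, T] with -100 on positions the CRF should ignore.
model.train()
labels = torch.full_like(enc["input_ids"], -100)
num_inner = enc["input_ids"].shape[1] - 2                          # tokens between [CLS] and [SEP]
labels[0, 1:-1] = torch.randint(0, len(id2label), (num_inner,))    # dummy targets for the sketch
out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], labels=labels)
out["loss"].backward()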
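
For the label and mask handling in the training branch (steps 2 to 5 above), a small self-contained sketch of the same tensor operations on a made-up sequence, showing what the CRF actually receives; the tensor values are invented for illustration.

import torch

# One sequence of length T = 6: [CLS] w1 w2 w3 [SEP] [PAD]
# -100 marks positions the loss should ignore ([CLS], [SEP], padding).
labels = torch.tensor([[-100, 1, 2, 0, -100, -100]])

tags = labels[:, 1:-1].clone()    # drop first and last positions -> [[1, 2, 0, -100]]
crf_mask = (tags != -100)         # [[True, True, True, False]]
tags[~crf_mask] = 0               # masked positions get dummy label 0 -> [[1, 2, 0, 0]]

# torchcrf additionally requires the first timestep of every sequence to be
# unmasked, which is why the model flips crf_mask[:, 0] on (with dummy tag 0)
# whenever it is off.
assert crf_mask[:, 0].all()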