
Commit 26857a1

Author: Christian Newman (committed)

Tested tree- and LM-based run and train. Added thorough documentation of how the NN code works.
1 parent f067186 commit 26857a1

File tree

6 files changed: +349 −141 lines


main

Lines changed: 1 addition & 2 deletions
@@ -67,12 +67,11 @@ if __name__ == "__main__":
             download_files()
             train_tree(config)
         elif args.model_type == "lm_based":
-            download_files()
             train_lm(SCRIPT_DIR)
 
     elif args.mode == "run":
         if args.model_type == "tree_based":
-            config = load_config_tree()
+            config = load_config_tree(SCRIPT_DIR)
             # Inject overrides
             download_files()
             config["model_type"] = args.model_type

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -6,5 +6,6 @@ pytorch-crf==0.7.2
 scikit-learn==1.6.1
 spiral @ git+https://github.com/cnewman/spiral.git@dff537320c15849c10e583968036df2d966eddee
 torch==2.7.1
-transformers==4.52.4
 waitress==3.0.2
+gensim==4.3.3
+transformers[torch]

src/lm_based_tagger/distilbert_crf.py

Lines changed: 90 additions & 31 deletions
@@ -6,16 +6,39 @@
 
 class DistilBertCRFForTokenClassification(nn.Module):
     """
-    DistilBERT ➜ dropout ➜ linear projection ➜ CRF.
-    The CRF layer models label-to-label transitions, so the model
-    is optimised at *sequence* level rather than *token* level.
+    Token-level classifier that combines DistilBERT with a CRF layer for structured prediction.
+
+    Architecture:
+        input_ids, attention_mask
+
+        DistilBERT (pretrained encoder)
+
+        Dropout
+
+        Linear layer (projects hidden size → num_labels)
+
+        CRF layer (models sequence-level transitions)
+
+    Training:
+        - Uses negative log-likelihood from CRF as loss.
+        - Learns both emission scores (token-level confidence) and
+          transition scores (label-to-label sequence consistency).
+
+    Inference:
+        - Uses Viterbi decoding to predict the most likely sequence of labels.
+
+    Output:
+        During training:
+            {"loss": ..., "logits": ...}
+        During inference:
+            {"logits": ..., "predictions": List[List[int]]}
+
+    Example input shape:
+        input_ids: [B, T] — e.g. [16, 128]
+        attention_mask: [B, T] — 1 for real tokens, 0 for padding
+        logits: [B, T, C] — C = number of label classes
     """
-    def __init__(self,
-                 num_labels: int,
-                 id2label: dict,
-                 label2id: dict,
-                 pretrained_name: str = "distilbert-base-uncased",
-                 dropout_prob: float = 0.1):
+    def __init__(self, num_labels: int, id2label: dict, label2id: dict, pretrained_name: str = "distilbert-base-uncased", dropout_prob: float = 0.1):
         super().__init__()
 
         self.config = DistilBertConfig.from_pretrained(

@@ -29,11 +52,34 @@ def __init__(self,
         self.classifier = nn.Linear(self.config.hidden_size, num_labels)
         self.crf = CRF(num_labels, batch_first=True)
 
-    def forward(self,
-                input_ids=None,
-                attention_mask=None,
-                labels=None,
-                **kwargs):
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        """
+        Forward pass for training or inference.
+
+        Args:
+            input_ids (Tensor): Token IDs of shape [B, T]
+            attention_mask (Tensor): Attention mask of shape [B, T]
+            labels (Tensor, optional): Ground-truth labels of shape [B, T]. Required during training.
+            kwargs: Any additional DistilBERT-compatible inputs (e.g., head_mask, position_ids, etc.)
+
+        Returns:
+            If labels are provided (training mode):
+                dict with:
+                    - loss (Tensor): scalar negative log-likelihood from CRF
+                    - logits (Tensor): emission scores of shape [B, T, C]
+
+            If labels are not provided (inference mode):
+                dict with:
+                    - logits (Tensor): emission scores of shape [B, T, C]
+                    - predictions (List[List[int]]): decoded label IDs from CRF,
+                      one list per sequence,
+                      each of length T-2 (excluding [CLS] and [SEP])
+
+        Notes:
+            - logits: [B, T, C], where B = batch size, T = sequence length, C = number of label classes
+            - predictions: List[List[int]], where each inner list has length T-2
+              (i.e., excludes [CLS] and [SEP]) and contains Viterbi-decoded label IDs
+        """
 
         # Hugging Face occasionally injects helper fields (e.g. num_items_in_batch)
         # Filter `kwargs` down to what DistilBertModel.forward actually accepts.

@@ -48,36 +94,49 @@ def forward(self,
             attention_mask=attention_mask,
             **bert_kwargs,
         )
-        # —— Build emissions once ——————————————————————————————
-        sequence_output = self.dropout(outputs[0])          # [B, T, H]
-        emission_scores = self.classifier(sequence_output)  # [B, T, C]
+        # 1) Compute per-token emission scores
+        #    Applies dropout to the BERT hidden states, then projects them to label logits.
+        #    Shape: [B, T, C], where B=batch size, T=sequence length, C=number of classes
+        sequence_output = self.dropout(outputs[0])
+        emission_scores = self.classifier(sequence_output)
 
-        # ============================== TRAINING ==============================
         if labels is not None:
-            # 1. Drop [CLS] (idx 0) and [SEP] (idx -1)
-            emissions = emission_scores[:, 1:-1, :]  # [B, T-2, C]
-            tags = labels[:, 1:-1].clone()           # [B, T-2]
-            crf_mask = (tags != -100)                # True = keep
+            # 2) Remove [CLS] and [SEP] special tokens from emissions and labels
+            #    These tokens were added by the tokenizer but are not part of the identifier
+            emissions = emission_scores[:, 1:-1, :]  # [B, T-2, C]
+            tags = labels[:, 1:-1].clone()           # [B, T-2]
 
-            # 2. For any position that's masked-off ➜ set tag to a valid id (0)
+            # 3) Create a mask: True where label is valid, False where label == -100
+            #    The CRF will use this to ignore special/padded tokens
+            crf_mask = (tags != -100)
+
+            # 4) Replace invalid label positions (-100) with a dummy label (e.g., 0)
+            #    This is required because CRF expects a label at every position, even if masked
             tags[~crf_mask] = 0
 
-            # 3. Guarantee first timestep is ON for every sequence
+            # 5) Ensure the first token of every sequence is active in the CRF mask
+            #    This avoids CRF errors when the first token is masked out (which breaks decoding)
             first_off = (~crf_mask[:, 0]).nonzero(as_tuple=True)[0]
             if len(first_off):
-                crf_mask[first_off, 0] = True  # flip mask to ON
-                tags[first_off, 0] = 0         # give it tag 0
+                crf_mask[first_off, 0] = True
+                tags[first_off, 0] = 0  # assign a dummy label
 
+            # 6) Compute CRF negative log-likelihood loss
             loss = -self.crf(emissions, tags, mask=crf_mask, reduction="mean")
            return {"loss": loss, "logits": emission_scores}
 
-        # ============================= INFERENCE ==============================
         else:
-            crf_mask = attention_mask[:, 1:-1].bool()  # [B, T-2]
-            emissions = emission_scores[:, 1:-1, :]    # [B, T-2, C]
+            # INFERENCE MODE
+
+            # 2) Remove [CLS] and [SEP] from emissions and build CRF mask from attention
+            #    Only use the inner content of the input sequence
+            crf_mask = attention_mask[:, 1:-1].bool()  # [B, T-2]
+            emissions = emission_scores[:, 1:-1, :]    # [B, T-2, C]
+
+            # 3) Run Viterbi decoding to get best label sequence for each input
             best_paths = self.crf.decode(emissions, mask=crf_mask)
-            return {"logits": emission_scores,
-                    "predictions": best_paths}
+            return {"logits": emission_scores, "predictions": best_paths}
+
     @classmethod
     def from_pretrained(cls, ckpt_dir, local=False, **kw):
         from safetensors.torch import load_file as load_safe_file
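
Note (editorial): the docstrings added above define the model's training and inference contracts. The snippet below is a minimal inference sketch, not part of the commit; it assumes the class is importable from src/lm_based_tagger/distilbert_crf.py with the repository root on PYTHONPATH, and the three-label id2label map is a made-up placeholder rather than the project's real grammar-tag set.

import torch
from transformers import DistilBertTokenizerFast

# Assumes the repo root is on PYTHONPATH; adjust the import to your layout.
from src.lm_based_tagger.distilbert_crf import DistilBertCRFForTokenClassification

# Placeholder label mapping for illustration only; the real tag set lives in the project config.
id2label = {0: "N", 1: "V", 2: "NM"}
label2id = {v: k for k, v in id2label.items()}

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertCRFForTokenClassification(
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
model.eval()  # pretrained encoder plus randomly initialised classifier/CRF; a real run loads a fine-tuned checkpoint

# One identifier, pre-split into tokens (as prepare_dataset would produce).
tokens = ["get", "Employee", "Name"]
enc = tokenizer([tokens], is_split_into_words=True, return_tensors="pt")

with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

# out["logits"] has shape [B, T, C]; out["predictions"] holds Viterbi-decoded label ids,
# each inner list of length T-2 because forward() strips [CLS] and [SEP] before the CRF.
print(out["predictions"])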

src/lm_based_tagger/distilbert_preprocessing.py

Lines changed: 61 additions & 12 deletions
@@ -1,14 +1,8 @@
 import re
-from nltk import pos_tag
-import nltk
 from difflib import SequenceMatcher
 import pandas as pd
 from datasets import Dataset
 
-# Download once (we'll just do it quietly here)
-nltk.download('averaged_perceptron_tagger_eng', quiet=True)
-nltk.download('universal_tagset', quiet=True)
-
 # === Constants ===
 VOWELS = set("aeiou")
 LOW_FREQ_TAGS = {"CJ", "VM", "PRE", "V"}

@@ -27,15 +21,13 @@
     "hungarian",
     "cvr",
     "digit",
-    #"nltk"
 ]
 
 FEATURE_FUNCTIONS = {
     "context": lambda row, tokens: CONTEXT_MAP.get(row["CONTEXT"].strip().upper(), "@unknown"),
     "hungarian": lambda row, tokens: detect_hungarian_prefix(tokens[0]) if tokens else "@hung_none",
     "cvr": lambda row, tokens: consonant_vowel_ratio_bucket(tokens),
     "digit": lambda row, tokens: detect_digit_feature(tokens),
-    "nltk": lambda row, tokens: "@nltk_" + '-'.join(tag.lower() for _, tag in pos_tag(tokens, tagset="universal"))
 }
 
 def get_feature_tokens(row, tokens):

@@ -99,6 +91,38 @@ def normalize_language(lang_str):
     return "@lang_" + lang_str.strip().lower().replace("++", "pp").replace("#", "sharp")
 
 def prepare_dataset(df: pd.DataFrame, label2id: dict):
+    """
+    Converts a DataFrame of identifier tokens and grammar tags into a HuggingFace Dataset
+    formatted for NER training with feature and position tokens.
+
+    Each row in the input DataFrame should contain:
+        - tokens: List[str] (e.g., ['get', 'Employee', 'Name'])
+        - tags: List[str] (e.g., ['V', 'NM', 'N'])
+        - CONTEXT: str (e.g., 'function')
+
+    The function adds:
+        - Feature tokens: ['@hung_get', '@no_digit', '@cvr_mid', '@func']
+        - Interleaved position and real tokens:
+            ['@pos_0', 'get', '@pos_1', 'Employee', '@pos_2', 'Name']
+
+    The NER tags are aligned so that:
+        - Feature tokens and position markers get label -100 (ignored in loss)
+        - Real tokens are converted from grammar tags using `label2id`
+
+    Example Input:
+        df = pd.DataFrame([{
+            "tokens": ["get", "Employee", "Name"],
+            "tags": ["V", "NM", "N"],
+            "CONTEXT": "function"
+        }])
+
+    Example Output:
+        Dataset with:
+            tokens: ['@hung_get', '@no_digit', '@cvr_mid', '@func',
+                     '@pos_0', 'get', '@pos_1', 'Employee', '@pos_2', 'Name']
+            ner_tags: [-100, -100, -100, -100,
+                       -100, 1, -100, 2, -100, 3]  # assuming label2id = {"V": 1, "NM": 2, "N": 3}
+    """
     rows = []
     for _, row in df.iterrows():
         tokens = row["tokens"]
@@ -123,9 +147,34 @@ def prepare_dataset(df: pd.DataFrame, label2id: dict):
         "ner_tags": [r["ner_tags"] for r in rows]
     })
 
-def tokenize_and_align_labels(example, tokenizer):
+def tokenize_and_align_labels(sample, tokenizer):
+    """
+    Tokenizes an example and aligns NER labels with subword tokens.
+
+    The input `sample` comes from `prepare_dataset()` and contains:
+        - tokens: List[str], including feature and position tokens
+        - ner_tags: List[int], aligned with `tokens`, with -100 for ignored tokens
+
+    This function:
+        - Uses `is_split_into_words=True` to tokenize each item in `tokens`
+        - Uses `tokenizer.word_ids()` to map each subword back to its original token index
+        - Assigns the corresponding label (or -100) for each subword token
+
+    Example Input:
+        sample = {
+            "tokens": ['@hung_get', '@no_digit', '@cvr_mid', '@func',
+                       '@pos_0', 'get', '@pos_1', 'Employee', '@pos_2', 'Name'],
+            "ner_tags": [-100, -100, -100, -100,
+                         -100, 1, -100, 2, -100, 3]
+        }
+
+    Assuming 'Employee' is tokenized to ['Em', '##ployee'],
+    Example Output:
+        tokenized["labels"] = [-100, -100, -100, -100,
+                               -100, 1, -100, 2, 2, -100, 3]
+    """
     tokenized = tokenizer(
-        example["tokens"],
+        sample["tokens"],
         truncation=True,
         is_split_into_words=True
     )

@@ -136,8 +185,8 @@ def tokenize_and_align_labels(example, tokenizer):
     for word_id in word_ids:
         if word_id is None:
             labels.append(-100)
-        elif word_id < len(example["ner_tags"]):
-            labels.append(example["ner_tags"][word_id])
+        elif word_id < len(sample["ner_tags"]):
+            labels.append(sample["ner_tags"][word_id])
         else:
             labels.append(-100)
 