Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions edsnlp/metrics/doc_classif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from typing import Any, Dict, Iterable, Optional, Tuple, Union

from spacy.tokens import Doc
from spacy.training import Example

from edsnlp import registry
from edsnlp.metrics import make_examples


def _precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Return (precision, recall, F1) from raw counts, using 0.0 when undefined."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall) / denom if denom > 0 else 0.0
    return precision, recall, f1


def doc_classification_metric(
    examples: Union[Tuple[Iterable[Doc], Iterable[Doc]], Iterable[Example]],
    label_attr: str = "label",
    micro_key: str = "micro",
    macro_key: str = "macro",
    filter_expr: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Scores document-level classification (accuracy, precision, recall, F1).

    Parameters
    ----------
    examples: Examples
        The examples to score, either a tuple of (golds, preds) or a list of
        spacy.training.Example objects
    label_attr: str
        The Doc._ attribute containing the label
    micro_key: str
        The key to use to store the micro-averaged results
    macro_key: str
        The key to use to store the macro-averaged results
    filter_expr: str
        The filter expression to use to filter the documents. It is evaluated
        as the body of `lambda doc: ...` on the reference (gold) document, and
        must therefore only come from a trusted source.

    Returns
    -------
    Dict[str, Any]
        One entry per label (in order of first appearance, gold labels first)
        holding "f", "p", "r", "tp", "fp", "fn", "support" and "positives",
        followed by the `micro_key` entry (which also holds "accuracy") and
        the `macro_key` entry (unweighted mean of the per-label scores).
    """
    # Local import: keeps this function self-contained without touching the
    # module-level import block.
    from collections import Counter

    examples = make_examples(examples)
    if filter_expr is not None:
        # SECURITY: `eval` executes arbitrary code. `filter_expr` must never be
        # built from untrusted input.
        filter_fn = eval(f"lambda doc: {filter_expr}")
        examples = [eg for eg in examples if filter_fn(eg.reference)]

    # Single pass over the examples (they may be a one-shot iterable).
    pred_labels = []
    gold_labels = []
    for eg in examples:
        pred_labels.append(getattr(eg.predicted._, label_attr, None))
        gold_labels.append(getattr(eg.reference._, label_attr, None))

    # Count every outcome once instead of re-scanning the data per label.
    tp_counts = Counter()
    fp_counts = Counter()
    fn_counts = Counter()
    for pred, gold in zip(pred_labels, gold_labels):
        if pred == gold:
            tp_counts[gold] += 1
        else:
            # A mismatch is a false positive for the predicted label and a
            # false negative for the gold label.
            fp_counts[pred] += 1
            fn_counts[gold] += 1

    # Deterministic label order (first appearance) rather than set order, so
    # that the result layout and the float sums below are reproducible.
    ordered_labels = dict.fromkeys(gold_labels)
    ordered_labels.update(dict.fromkeys(pred_labels))
    ordered_labels.pop(None, None)  # a missing label is not a class

    per_label = {}
    for label in ordered_labels:
        tp = tp_counts[label]
        fp = fp_counts[label]
        fn = fn_counts[label]
        precision, recall, f1 = _precision_recall_f1(tp, fp, fn)
        per_label[label] = {
            "f": f1,
            "p": precision,
            "r": recall,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "support": tp + fn,
            "positives": tp + fp,
        }

    # Micro average: every document has exactly one gold and one predicted
    # label, so each error counts as both one FP and one FN. Documents where
    # both labels are missing (None) compare equal and count as correct.
    n_docs = len(pred_labels)
    total_tp = sum(tp_counts.values())
    total_fp = n_docs - total_tp
    total_fn = total_fp
    micro_precision, micro_recall, micro_f1 = _precision_recall_f1(
        total_tp, total_fp, total_fn
    )
    accuracy = total_tp / n_docs if n_docs > 0 else 0.0

    # Macro average is computed from `per_label` (not from `results`) so that
    # a label named like `micro_key` cannot corrupt it.
    n_labels = len(per_label)
    if n_labels:
        macro_precision = sum(s["p"] for s in per_label.values()) / n_labels
        macro_recall = sum(s["r"] for s in per_label.values()) / n_labels
        macro_f1 = sum(s["f"] for s in per_label.values()) / n_labels
    else:
        macro_precision = macro_recall = macro_f1 = 0.0

    results = dict(per_label)
    results[micro_key] = {
        "accuracy": accuracy,
        "f": micro_f1,
        "p": micro_precision,
        "r": micro_recall,
        "tp": total_tp,
        "fp": total_fp,
        "fn": total_fn,
        "support": len(gold_labels),
        "positives": len(pred_labels),
    }
    results[macro_key] = {
        "f": macro_f1,
        "p": macro_precision,
        "r": macro_recall,
        # Kept for backward compatibility: both keys hold the number of classes.
        "support": n_labels,
        "classes": n_labels,
    }
    return results


@registry.metrics.register("eds.doc_classification")
class DocClassificationMetric:
    """
    Registered, configurable wrapper around `doc_classification_metric`.

    Parameters
    ----------
    label_attr: str
        The Doc._ attribute containing the label
    micro_key: str
        The key to use to store the micro-averaged results
    filter_expr: Optional[str]
        The filter expression to use to filter the documents
    macro_key: str
        The key to use to store the macro-averaged results. Declared last so
        that existing positional calls keep their meaning.
    """

    def __init__(
        self,
        label_attr: str = "label",
        micro_key: str = "micro",
        filter_expr: Optional[str] = None,
        macro_key: str = "macro",
    ):
        self.label_attr = label_attr
        self.micro_key = micro_key
        self.macro_key = macro_key
        self.filter_expr = filter_expr

    def __call__(self, *examples):
        """Score `examples` and return the metrics dictionary."""
        return doc_classification_metric(
            examples,
            label_attr=self.label_attr,
            micro_key=self.micro_key,
            macro_key=self.macro_key,
            filter_expr=self.filter_expr,
        )


# Names exported by `from <module> import *`: the scoring function and its
# registered, configurable wrapper.
__all__ = ["doc_classification_metric", "DocClassificationMetric"]
2 changes: 2 additions & 0 deletions edsnlp/pipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,7 @@
from .trainable.embeddings.span_pooler.factory import create_component as span_pooler
from .trainable.embeddings.transformer.factory import create_component as transformer
from .trainable.embeddings.text_cnn.factory import create_component as text_cnn
from .trainable.embeddings.doc_pooler.factory import create_component as doc_pooler
from .trainable.doc_classifier.factory import create_component as doc_classifier
from .misc.split import Split as split
from .misc.explode import Explode as explode
1 change: 1 addition & 0 deletions edsnlp/pipes/trainable/doc_classifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .factory import create_component
222 changes: 222 additions & 0 deletions edsnlp/pipes/trainable/doc_classifier/doc_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import os
import pickle
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union

import torch
from spacy.tokens import Doc
from typing_extensions import NotRequired, TypedDict

from edsnlp.core.pipeline import PipelineProtocol
from edsnlp.core.torch_component import BatchInput, TorchComponent
from edsnlp.pipes.base import BaseComponent
from edsnlp.pipes.trainable.embeddings.typing import (
WordContextualizerComponent,
WordEmbeddingComponent,
)
from edsnlp.utils.bindings import Attributes

class DocClassifierBatchInput(TypedDict):
    """Collated batch consumed by the document classifier's forward pass."""

    # Collated input of the embedding sub-component.
    embedding: BatchInput
    # Target tensor; only present for supervised (training) batches.
    targets: NotRequired[torch.Tensor]

class DocClassifierBatchOutput(TypedDict):
    """Result of the document classifier's forward pass."""

    # Training loss, or None when no loss was computed.
    loss: Optional[torch.Tensor]
    # Predicted labels, or None when no prediction was computed.
    labels: Optional[torch.Tensor]


class TrainableDocClassifier(
    TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
    BaseComponent,
):
    """
    Trainable component that assigns one label to each document.

    The output of the `embedding` component is passed through a linear layer
    to obtain one logit per class. The predicted label is written to
    `doc._.<label_attr>`.

    Parameters
    ----------
    nlp: Optional[PipelineProtocol]
        The pipeline object
    name: str
        Name of the component
    embedding: Union[WordEmbeddingComponent, WordContextualizerComponent]
        The embedding component. It must expose an `output_size` attribute and
        its forward output must contain an "embeddings" entry.
    num_classes: Optional[int]
        Number of classes. When omitted, it is taken from `label2id` if
        provided, or inferred in `post_init`.
    label_attr: str
        The Doc._ attribute holding the gold / predicted label
    label2id: Optional[Dict[str, int]]
        Mapping from label to class index
    id2label: Optional[Dict[int, str]]
        Mapping from class index to label
    loss_fn
        Custom loss callable `(logits, targets) -> loss`. Defaults to
        `torch.nn.CrossEntropyLoss`.
    labels: Optional[Sequence[str]]
        Explicit list of labels, used to build `label2id` when it is not given
    class_weights: Optional[Union[Dict[str, float], str]]
        Either a `{label: frequency}` dict, or the path to a pickle file
        containing such a dict. Frequencies are converted to loss weights by
        `_compute_class_weights`.
    """

    def __init__(
        self,
        nlp: Optional[PipelineProtocol] = None,
        name: str = "doc_classifier",
        *,
        embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
        num_classes: Optional[int] = None,
        label_attr: str = "label",
        label2id: Optional[Dict[str, int]] = None,
        id2label: Optional[Dict[int, str]] = None,
        loss_fn=None,
        labels: Optional[Sequence[str]] = None,
        class_weights: Optional[Union[Dict[str, float], str]] = None,
    ):
        # These attributes are assigned before `super().__init__` (as in the
        # original code), presumably because the parent initializer calls
        # `set_extensions`, which reads `self.label_attr` - TODO confirm.
        self.label_attr: Attributes = label_attr
        self.label2id = label2id or {}
        self.id2label = id2label or {}
        # If only one direction of the mapping was given, derive the other so
        # that encoding (training) and decoding (inference) stay consistent.
        if self.label2id and not self.id2label:
            self.id2label = {idx: label for label, idx in self.label2id.items()}
        elif self.id2label and not self.label2id:
            self.label2id = {label: idx for idx, label in self.id2label.items()}
        self.labels = labels
        self.class_weights = class_weights

        super().__init__(nlp, name)
        self.embedding = embedding

        # The effective loss is resolved in `post_init`, once the class
        # weights can be computed.
        self._loss_fn = loss_fn
        self.loss_fn = None

        if not hasattr(self.embedding, "output_size"):
            raise ValueError(
                "The embedding component must have an 'output_size' attribute."
            )
        embedding_size = self.embedding.output_size
        # A known label mapping is enough to size the classification head:
        # `post_init` only builds the head when `label2id` is empty.
        if not num_classes and self.label2id:
            num_classes = len(self.label2id)
        if num_classes:
            self.classifier = torch.nn.Linear(embedding_size, num_classes)

    def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
        """
        Compute class weights from a frequency dictionary.

        Uses "balanced" inverse-frequency weighting:
        `weight = total_samples / (num_classes * frequency)`, where
        `num_classes` is `len(self.label2id)`.

        Labels of `freq_dict` that are not in `label2id` are ignored, and
        labels of `label2id` that are missing from `freq_dict` keep a weight
        of 0.

        Raises
        ------
        ZeroDivisionError
            If a known label has a frequency of 0.
        """
        total_samples = sum(freq_dict.values())
        num_classes = len(self.label2id)

        weights = torch.zeros(num_classes)

        for label, freq in freq_dict.items():
            if label not in self.label2id:
                continue
            if freq == 0:
                # Same exception type as the bare division, with a message
                # that tells the user which label is at fault.
                raise ZeroDivisionError(
                    f"Cannot compute a class weight for label '{label}': "
                    "its frequency is 0."
                )
            weights[self.label2id[label]] = total_samples / (num_classes * freq)

        return weights

    def _load_class_weights_from_file(self, filepath: str) -> Dict[str, int]:
        """Load the `{label: frequency}` dictionary from a pickle file."""
        # SECURITY: unpickling executes arbitrary code; only load files from a
        # trusted source.
        with open(filepath, "rb") as f:
            return pickle.load(f)

    def set_extensions(self) -> None:
        """Register the `Doc._.<label_attr>` extension if it does not exist."""
        super().set_extensions()
        if not Doc.has_extension(self.label_attr):
            # NOTE(review): the default is an empty dict, so unlabeled
            # documents yield `{}` (not None) - see `preprocess_supervised`.
            Doc.set_extension(self.label_attr, default={})

    def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
        """
        Finalize the component before training: build the label mapping and
        the classification head if needed, then set up the loss function.

        Parameters
        ----------
        gold_data: Iterable[Doc]
            Gold documents, used to collect the labels when neither
            `label2id` nor `labels` was provided
        exclude: Set[str]
            Forwarded to the parent `post_init`
        """
        if not self.label2id:
            if self.labels is not None:
                # Keep the user-provided order (dropping duplicates) so that
                # class ids are reproducible from one run to the next.
                labels = list(dict.fromkeys(self.labels))
            else:
                found = set()
                for doc in gold_data:
                    label = getattr(doc._, self.label_attr, None)
                    if isinstance(label, str):
                        found.add(label)
                # Sorted for the same reproducibility reason: set iteration
                # order of strings changes between interpreter runs.
                labels = sorted(found)
            if labels:
                self.label2id = {label: i for i, label in enumerate(labels)}
                self.id2label = {i: label for i, label in enumerate(labels)}
                print("num classes:", len(self.label2id))
                self.classifier = torch.nn.Linear(
                    self.embedding.output_size, len(self.label2id)
                )

        weight_tensor = None
        if self.class_weights is not None:
            if isinstance(self.class_weights, str):
                freq_dict = self._load_class_weights_from_file(self.class_weights)
                weight_tensor = self._compute_class_weights(freq_dict)
            elif isinstance(self.class_weights, dict):
                weight_tensor = self._compute_class_weights(self.class_weights)

            print(f"Using class weights: {weight_tensor}")

        # A user-provided loss takes precedence; class weights are then unused.
        if self._loss_fn is not None:
            self.loss_fn = self._loss_fn
        else:
            self.loss_fn = torch.nn.CrossEntropyLoss(weight=weight_tensor)

        super().post_init(gold_data, exclude=exclude)

    def preprocess(self, doc: Doc) -> Dict[str, Any]:
        """Extract the features of `doc` by delegating to the embedding."""
        return {"embedding": self.embedding.preprocess(doc)}

    def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
        """
        Extract the features of `doc` together with its gold label.

        Raises
        ------
        ValueError
            If the document has no gold label, or if its label is not part of
            the `label2id` mapping.
        """
        preps = self.preprocess(doc)
        label = getattr(doc._, self.label_attr, None)
        # `{}` is the default registered by `set_extensions`, i.e. the value
        # of a document that was never labeled.
        if label is None or (isinstance(label, dict) and not label):
            raise ValueError(
                f"Document does not have a gold label in 'doc._.{self.label_attr}'"
            )
        if isinstance(label, str) and self.label2id:
            if label not in self.label2id:
                raise ValueError(f"Label '{label}' not in label2id mapping.")
            label = self.label2id[label]
        return {
            **preps,
            "targets": torch.tensor(label, dtype=torch.long),
        }

    def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
        """Collate preprocessed samples; targets are stacked when present."""
        embeddings = self.embedding.collate(batch["embedding"])
        batch_input: DocClassifierBatchInput = {"embedding": embeddings}
        if "targets" in batch:
            batch_input["targets"] = torch.stack(batch["targets"])
        return batch_input

    def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
        """
        Forward pass: compute embeddings, classify, and calculate loss
        if targets provided.

        Returns a dict with "loss" set (and "labels" None) when the batch
        contains targets, and with "labels" set to the argmax of the logits
        (and "loss" None) otherwise.
        """
        pooled = self.embedding(batch["embedding"])
        embeddings = pooled["embeddings"]

        logits = self.classifier(embeddings)

        output: DocClassifierBatchOutput = {}
        if "targets" in batch:
            output["loss"] = self.loss_fn(logits, batch["targets"])
            output["labels"] = None
        else:
            output["loss"] = None
            output["labels"] = torch.argmax(logits, dim=-1)
        return output

    def postprocess(self, docs, results, input):
        """Write the predicted labels to `doc._.<label_attr>` and return docs."""
        labels = results["labels"]
        if isinstance(labels, torch.Tensor):
            labels = labels.tolist()
        for doc, label in zip(docs, labels):
            # Decode the class index when a mapping is available; unknown
            # indices are stored as-is.
            if self.id2label and isinstance(label, int):
                label = self.id2label.get(label, label)
            setattr(doc._, self.label_attr, label)
        return docs

    def to_disk(self, path, *, exclude=None):
        """
        Save classifier state to disk.

        The label attribute name and label mappings are pickled to
        `path / "label_attr.pkl"`; the rest is delegated to the parent class.

        Parameters
        ----------
        path
            Target directory (must support the `/` operator, e.g. a
            `pathlib.Path`)
        exclude: Optional[Set[str]]
            Identifiers of the components that were already saved. This
            component adds itself to it, and does nothing if already listed.
        """
        # A fresh set per call: the previous `exclude=set()` default was
        # shared between calls and mutated below, so a second save using the
        # default silently returned without writing anything.
        if exclude is None:
            exclude = set()
        repr_id = object.__repr__(self)
        if repr_id in exclude:
            return
        exclude.add(repr_id)
        os.makedirs(path, exist_ok=True)
        data_path = path / "label_attr.pkl"
        with open(data_path, "wb") as f:
            pickle.dump(
                {
                    "label_attr": self.label_attr,
                    "label2id": self.label2id,
                    "id2label": self.id2label,
                },
                f,
            )
        return super().to_disk(path, exclude=exclude)

    @classmethod
    def from_disk(cls, path, **kwargs):
        """
        Load classifier from disk.

        Restores `label_attr`, `label2id` and `id2label` from
        `path / "label_attr.pkl"` on the object returned by the parent
        `from_disk`.
        """
        # NOTE(review): as a classmethod, `super().from_disk` is called bound
        # to the class, not an instance - confirm the parent supports this.
        data_path = path / "label_attr.pkl"
        # SECURITY: unpickling executes arbitrary code; only load models from
        # a trusted source.
        with open(data_path, "rb") as f:
            data = pickle.load(f)
        obj = super().from_disk(path, **kwargs)
        obj.label_attr = data.get("label_attr", "label")
        obj.label2id = data.get("label2id", {})
        obj.id2label = data.get("id2label", {})
        return obj
Loading
Loading