4 changes: 2 additions & 2 deletions requirements.txt
@@ -28,9 +28,9 @@ spacy==3.7.5
streamlit==1.37.0
streamlit-authenticator==0.4.2
tabulate==0.9.0
-transformers==4.52.4
+transformers
trl==0.15.2
torchmetrics==1.4.1
-unsloth==2025.3.17
+unsloth
vllm==0.8.1
xlrd==1.2.0
53 changes: 53 additions & 0 deletions src/atgen/metrics/base_metric.py
@@ -0,0 +1,53 @@
import re
import string
from abc import ABC, abstractmethod
from typing import List, Optional

from pydantic import BaseModel



class MetricConfig(BaseModel):
aggregate: bool = True


class BaseMetric(ABC):
def __init__(self, config: MetricConfig):
self.config = config

@abstractmethod
    def compute(self, predictions: List[str], references: List[str], **kwargs) -> dict:
raise NotImplementedError

def _preprocess_text(self, text: str,
do_lowercase: bool = True,
do_remove_punctuation: bool = True,
do_remove_extra_spaces: bool = True,
do_remove_stopwords: bool = False,
                         stopwords: Optional[List[str]] = None) -> str:
# Convert to lowercase
if do_lowercase:
text = text.lower()

        # Remove punctuation
        if do_remove_punctuation:
            # Replace hyphens that are not between two word characters with spaces
            text = re.sub(r'(?<!\w)-|-(?!\w)', ' ', text)
            # Strip the remaining punctuation, keeping intra-word hyphens
            translator = str.maketrans('', '', string.punctuation.replace('-', ''))
            text = text.translate(translator)

# Normalize whitespace
if do_remove_extra_spaces:
text = ' '.join(text.split())

# Remove stopwords
        if do_remove_stopwords:
            if stopwords is None:
                import nltk
                nltk.download('stopwords', quiet=True)
                stopwords = nltk.corpus.stopwords.words('english')
            stopword_set = set(stopwords)  # set membership check is O(1)
            text = ' '.join(w for w in text.split() if w not in stopword_set)

return text.strip()
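
A minimal usage sketch of `_preprocess_text`; the `_Echo` subclass and inputs are invented for illustration:

# Illustrative only: _Echo is a hypothetical stub that makes BaseMetric instantiable
from atgen.metrics.base_metric import BaseMetric, MetricConfig

class _Echo(BaseMetric):
    def compute(self, predictions, references, **kwargs) -> dict:
        return {}

m = _Echo(MetricConfig())
print(m._preprocess_text("Well-known FACTS, etc."))  # "well-known facts etc"
print(m._preprocess_text("a - b"))                   # "a b": standalone hyphen dropped
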
1 change: 1 addition & 0 deletions src/atgen/metrics/classic_metrics/__init__.py
@@ -0,0 +1 @@

82 changes: 82 additions & 0 deletions src/atgen/metrics/classic_metrics/abstractiveness.py
@@ -0,0 +1,82 @@
from typing import List, Optional

import numpy as np
from nltk import ngrams
from nltk.stem import porter
from rouge_score import tokenize

from atgen.metrics.base_metric import BaseMetric, MetricConfig



class AbstractivenessConfig(MetricConfig):
aggregate: bool = True

class Abstractiveness(BaseMetric):
def __init__(self, config: AbstractivenessConfig):
super().__init__(config)

    def _calculate_ngram_overlap(self, summary, text, n=1, use_modified=True):
        # Fraction of novel n-grams in `summary` relative to `text`; the modified
        # variant scores the fraction of summary tokens not covered by any copied
        # n-gram. Returns NaN when the summary has no n-grams of order `n`.
summary_ngrams = list(ngrams(summary, n))
text_ngrams = list(ngrams(text, n))

if len(summary_ngrams) > 0:
ngrams_intersection = set(summary_ngrams).intersection(set(text_ngrams))
if use_modified:
word_is_part_of_ngram_copied = [
any((x in ngram for ngram in ngrams_intersection)) for x in summary
]
return 1 - sum(word_is_part_of_ngram_copied) / len(
word_is_part_of_ngram_copied
)
else:
return sum([x not in ngrams_intersection for x in summary_ngrams]) / len(
summary_ngrams
)
return np.nan


    def compute(self, predictions: List[str], references: Optional[List[str]], sources: List[str], **kwargs) -> dict:
stemmer = porter.PorterStemmer()
tokenized_preds = [tokenize.tokenize(x, stemmer) for x in predictions]
tokenized_texts = [tokenize.tokenize(x, stemmer) for x in sources]
if references is not None:
tokenized_refs = [tokenize.tokenize(x, stemmer) for x in references]
else:
tokenized_refs = tokenized_preds

result = {}
for use_modified in [False, True]:
for n in range(1, 5):
pred_ngram_overlaps = []
label_ngram_overlaps = []
for pred, label, text in zip(
tokenized_preds, tokenized_refs, tokenized_texts
):
pred_pair_ngram_overlap = self._calculate_ngram_overlap(
pred, text, n, use_modified
)
pred_ngram_overlaps.append(pred_pair_ngram_overlap)
if references is not None:
label_pair_ngram_overlap = self._calculate_ngram_overlap(
label, text, n, use_modified
)
label_ngram_overlaps.append(label_pair_ngram_overlap)
key = f"ngram_overlap_{n}" if use_modified else f"novel_ngrams_{n}"

pred_ngram_overlaps = np.array(pred_ngram_overlaps)
cond_abs = ~np.isnan(pred_ngram_overlaps)
result[key + "_abs"] = pred_ngram_overlaps[cond_abs]

if references is not None:
label_ngram_overlaps = np.array(label_ngram_overlaps)
cond_rel = cond_abs & ~np.isnan(label_ngram_overlaps)
result[key + "_rel"] = (
pred_ngram_overlaps[cond_rel] / label_ngram_overlaps[cond_rel]
)

if self.config.aggregate:
for key, value in result.items():
result[key] = np.mean(value)

return {"abstractiveness": result}
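
As a sanity check, a rough sketch of what `_calculate_ngram_overlap` returns on a toy pair; the token lists are invented:

# Hypothetical toy example for _calculate_ngram_overlap
from atgen.metrics.classic_metrics.abstractiveness import Abstractiveness, AbstractivenessConfig

metric = Abstractiveness(AbstractivenessConfig())
summary = ["the", "cat", "sat"]          # bigrams: (the, cat), (cat, sat)
source  = ["the", "cat", "ran", "home"]  # shares only (the, cat)
# Plain variant: fraction of summary bigrams absent from the source -> 1/2
print(metric._calculate_ngram_overlap(summary, source, n=2, use_modified=False))  # 0.5
# Modified variant: fraction of summary tokens outside any copied bigram -> 1/3
print(metric._calculate_ngram_overlap(summary, source, n=2, use_modified=True))   # 0.333...
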
56 changes: 56 additions & 0 deletions src/atgen/metrics/classic_metrics/bleu.py
@@ -0,0 +1,56 @@
from typing import List, Optional

import numpy as np
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import corpus_bleu

from atgen.metrics.base_metric import BaseMetric, MetricConfig


class BleuConfig(MetricConfig):
pass


class Bleu(BaseMetric):
def __init__(self, config: BleuConfig):
super().__init__(config)


def _smoothing_function(self, p_n, references, hypothesis, hyp_len):
smoothed_p_n = []
for i, p_i in enumerate(p_n, start=1):
            # Smoothing is not applied to unigrams
            if i > 1:
                # If the hypothesis is shorter than the current order, there are no
                # n-grams to score, so the smoothed value is (0 + 1) / (0 + 1) = 1
                if hyp_len < i:
                    assert p_i.denominator == 1
                    smoothed_p_n.append(1)
# Otherwise apply smoothing
else:
smoothed_p_i = (p_i.numerator + 1) / (p_i.denominator + 1)
smoothed_p_n.append(smoothed_p_i)
else:
smoothed_p_n.append(p_i)
return smoothed_p_n

    def compute(self, predictions: List[str], references: List[str], sources: Optional[List[str]] = None, **kwargs) -> dict:
scores = []
for pred, ref in zip(predictions, references):
if isinstance(ref, str):
ref_list = [ref]
else:
ref_list = ref

tok_ref = [[word_tokenize(r) for r in ref_list]]
tok_pred = [word_tokenize(pred)]

try:
bleu_score = corpus_bleu(tok_ref, tok_pred, smoothing_function=self._smoothing_function)
scores.append(bleu_score)
except (KeyError, ZeroDivisionError):
scores.append(0.0)

if self.config.aggregate:
return {"bleu": float(np.mean(scores))}
else:
return {"bleu": scores}
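
A small usage sketch with invented sentences; with the custom smoothing, the one-word hypothesis does not zero out the higher-order precisions:

# Hypothetical usage of the Bleu metric defined above
# (nltk.download("punkt") may be required for word_tokenize)
from atgen.metrics.classic_metrics.bleu import Bleu, BleuConfig

bleu = Bleu(BleuConfig(aggregate=False))
preds = ["the cat sat on the mat", "hello"]
refs  = ["the cat sat on a mat", "hello there"]
print(bleu.compute(predictions=preds, references=refs))  # {"bleu": [s1, s2]}
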
31 changes: 31 additions & 0 deletions src/atgen/metrics/classic_metrics/exact_match.py
@@ -0,0 +1,31 @@
from atgen.metrics.base_metric import BaseMetric, MetricConfig
from typing import List, Optional
import numpy as np

class ExactMatchConfig(MetricConfig):
aggregate: bool = True


class ExactMatch(BaseMetric):
def __init__(self, config: ExactMatchConfig):
super().__init__(config)


    def compute(self, predictions: List[str], references: List[str], sources: Optional[List[str]] = None, **kwargs) -> dict:
if isinstance(references[0], list):
scores = np.array(
[
any(self._preprocess_text(pred) == self._preprocess_text(one_ref) for one_ref in ref)
for pred, ref in zip(predictions, references)
]
)
else:
scores = np.array(
[self._preprocess_text(pred) == self._preprocess_text(ref) for pred, ref in zip(predictions, references)]
)

if self.config.aggregate:
return {"exact_match": float(np.mean(scores))}
else:
return {"exact_match": scores}
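
A quick sketch with invented inputs, showing that matching happens on the preprocessed strings:

# Hypothetical usage of ExactMatch
from atgen.metrics.classic_metrics.exact_match import ExactMatch, ExactMatchConfig

em = ExactMatch(ExactMatchConfig())
print(em.compute(predictions=["Paris!", "Berlin"],
                 references=["paris", "Munich"]))  # {"exact_match": 0.5}
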
26 changes: 26 additions & 0 deletions src/atgen/metrics/classic_metrics/exact_match_math.py
@@ -0,0 +1,26 @@
from typing import List, Optional

import numpy as np

from atgen.metrics.base_metric import BaseMetric, MetricConfig


class ExactMatchMathConfig(MetricConfig):
    aggregate: bool = True


class ExactMatchMath(BaseMetric):
    def __init__(self, config: ExactMatchMathConfig):
        super().__init__(config)

    def compute(self, predictions: List[str], references: List[str], sources: Optional[List[str]] = None, **kwargs) -> dict:
        # Compare only the final answers after the GSM8K-style "#### " delimiter
        scores = np.array(
            [
                pred.split("#### ")[-1].lower() == ref.split("#### ")[-1].lower()
                for pred, ref in zip(predictions, references)
            ]
        )
)
if self.config.aggregate:
return {"exact_match_math": float(np.mean(scores))}
else:
return {"exact_match_math": scores}
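
A usage sketch with invented answers, assuming GSM8K-style targets that end in "#### <answer>":

# Hypothetical usage of ExactMatchMath
from atgen.metrics.classic_metrics.exact_match_math import ExactMatchMath, ExactMatchMathConfig

emm = ExactMatchMath(ExactMatchMathConfig())
preds = ["Add 2 and 3.\n#### 5", "So the total is\n#### 12"]
refs  = ["2 + 3 = 5\n#### 5", "#### 13"]
print(emm.compute(predictions=preds, references=refs))  # {"exact_match_math": 0.5}
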
27 changes: 27 additions & 0 deletions src/atgen/metrics/classic_metrics/rouge.py
@@ -0,0 +1,27 @@
from atgen.metrics.base_metric import BaseMetric, MetricConfig
from typing import List, Optional
from evaluate import load
import numpy as np


class RougeConfig(MetricConfig):
use_stemmer: bool = True



class Rouge(BaseMetric):
def __init__(self, config: RougeConfig):
super().__init__(config)
self.rouge = load("rouge")

    def compute(self, predictions: List[str], references: List[str], sources: Optional[List[str]] = None, **kwargs) -> dict:
        # use_aggregator=False returns per-sample scores; aggregation happens below
        rouge_scores = self.rouge.compute(
            predictions=predictions,
            references=references,
            use_stemmer=self.config.use_stemmer,
            use_aggregator=False,
        )

if self.config.aggregate:
return {k: float(np.mean(v)) for k, v in rouge_scores.items()}
else:
return rouge_scores
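
A usage sketch with invented texts; this relies on the `rouge_score` backend that `evaluate`'s "rouge" metric uses:

# Hypothetical usage of the Rouge wrapper
from atgen.metrics.classic_metrics.rouge import Rouge, RougeConfig

rouge = Rouge(RougeConfig(use_stemmer=True))
print(rouge.compute(predictions=["the cat sat on the mat"],
                    references=["a cat sat on the mat"]))
# -> {"rouge1": ..., "rouge2": ..., "rougeL": ..., "rougeLsum": ...}
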
37 changes: 37 additions & 0 deletions src/atgen/metrics/classic_metrics/sacrebleu.py
@@ -0,0 +1,37 @@
from typing import List, Optional

import numpy as np
from evaluate import load

from atgen.metrics.base_metric import BaseMetric, MetricConfig

sacrebleu = load("sacrebleu")


class SacrebleuConfig(MetricConfig):
pass


class Sacrebleu(BaseMetric):
def __init__(self, config: SacrebleuConfig):
super().__init__(config)

    def compute(self, predictions: List[str], references: List[str], sources: Optional[List[str]] = None, **kwargs) -> dict:
        if not isinstance(references[0], list):
            # Single reference per prediction: score the whole corpus at once
            sacrebleu_references = [[ref] for ref in references]
            sacrebleu_result = sacrebleu.compute(
                predictions=predictions, references=sacrebleu_references
            )
            return {"sacrebleu": float(sacrebleu_result.pop("score"))}
else:
sacrebleu_scores = []
for pred, ref in zip(predictions, references):
sacrebleu_result = sacrebleu.compute(
predictions=[pred], references=[ref]
)
sacrebleu_scores.append(sacrebleu_result.pop("score"))
if self.config.aggregate:
return {"sacrebleu": float(np.mean(sacrebleu_scores))}
else:
return {"sacrebleu": sacrebleu_scores}
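
A usage sketch with invented inputs; note that SacreBLEU scores are on a 0-100 scale:

# Hypothetical usage of Sacrebleu
from atgen.metrics.classic_metrics.sacrebleu import Sacrebleu, SacrebleuConfig

sb = Sacrebleu(SacrebleuConfig())
print(sb.compute(predictions=["the cat sat on the mat"],
                 references=["the cat sat on the mat"]))  # {"sacrebleu": 100.0}
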
36 changes: 36 additions & 0 deletions src/atgen/metrics/classic_metrics/word_length.py
@@ -0,0 +1,36 @@
from atgen.metrics.base_metric import BaseMetric, MetricConfig
from typing import List, Optional
import numpy as np


class WordLengthConfig(MetricConfig):
pass


class WordLength(BaseMetric):
def __init__(self, config: WordLengthConfig):
super().__init__(config)

    def compute(self, predictions: List[str], references: List[str], sources: Optional[List[str]] = None, **kwargs) -> dict:
# Calculate generated text lengths
gen_word_lengths = np.array([len(text.split()) for text in predictions])

# Calculate reference text lengths
if isinstance(references[0], list):
ref_word_lengths = np.array(
[
np.mean([len(text.split()) for text in ref])
for ref in references
]
)
else:
ref_word_lengths = np.array([len(ref.split()) for ref in references])

# Avoid division by zero
ref_word_lengths_safe = np.where(ref_word_lengths > 0, ref_word_lengths, 1)
relative_lengths = gen_word_lengths / ref_word_lengths_safe

if self.config.aggregate:
return {"word_length": float(np.mean(relative_lengths))}
else:
return {"word_length": relative_lengths}
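
A usage sketch with invented inputs; the metric is the ratio of prediction length to reference length in words:

# Hypothetical usage of WordLength
from atgen.metrics.classic_metrics.word_length import WordLength, WordLengthConfig

wl = WordLength(WordLengthConfig(aggregate=False))
print(wl.compute(predictions=["one two three"],
                 references=["one two three four"]))  # {"word_length": array([0.75])}
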