
Commit 7982a0f

feat: add support for non-deterministic models in GeneralSemanticRobustness and add BERTScore Dissimilarity (#184)
* Add the BERTScore computation to general semantic robustness
* Fix lint issue
* Use the correct default model for BERTScore
* Allow mocking helper models in unit tests
* Forceful Ray shutdown in the general semantic robustness integ test to avoid out-of-memory errors
* Change BERTScore from a similarity metric to a dissimilarity metric; add the normalization factor for stochastic models
* Fix the bug where max should be used instead of min in the baseline
* Enable verbose logging for integ tests to debug test failures
* Add additional verbose flags for integ tests
* Remove redundant Ray shutdown now that an autouse fixture takes care of it
* Remove perturbation-type-based integ tests to save time and disk space
* Move the BertscoreHelperModelTypes enumerator and tests; unroll the arguments of get_meteor_score and get_bert_score out of kwargs

---------

Co-authored-by: Aman Malhotra <[email protected]>
1 parent c18b93d · commit 7982a0f

File tree

10 files changed: +472 additions, -223 deletions

src/fmeval/constants.py

Lines changed: 3 additions & 0 deletions
@@ -121,3 +121,6 @@ class DatasetColumns(Enum):
 JUMPSTART_BUCKET_BASE_URL_FORMAT = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com"
 JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR = "JUMPSTART_BUCKET_BASE_URL_FORMAT"
 GENERATED_TEXT_JMESPATH_EXPRESSION = "*.output_keys.generated_text"
+
+# BERTScore
+BERTSCORE_DEFAULT_MODEL = "microsoft/deberta-xlarge-mnli"
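For context, this constant becomes the default value of the model_type_for_bertscore field on GeneralSemanticRobustnessConfig in the general_semantic_robustness.py diff below. A minimal, hypothetical usage sketch (the field names come from that diff; the chosen values are illustrative only, not a recommendation):

from fmeval.eval_algorithms.general_semantic_robustness import GeneralSemanticRobustnessConfig

# Override the default BERTScore backbone ("microsoft/deberta-xlarge-mnli")
# with the other allowed model; num_baseline_samples must be at least 2.
config = GeneralSemanticRobustnessConfig(
    num_perturbations=5,
    num_baseline_samples=4,
    model_type_for_bertscore="roberta-large-mnli",
)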

src/fmeval/eval_algorithms/general_semantic_robustness.py

Lines changed: 142 additions & 37 deletions
@@ -1,10 +1,12 @@
+import functools
 import itertools
 import logging
 
 import evaluate as hf_evaluate
 from dataclasses import dataclass
 from typing import Optional, List, Dict, Any
-
+import numpy as np
+from ray.data import Dataset
 
 from fmeval import util
 from fmeval.constants import (
@@ -47,6 +49,9 @@
 from fmeval.model_runners.composers.composers import PromptComposer
 from fmeval.model_runners.model_runner import ModelRunner
 from fmeval.perf_util import timed_block
+from fmeval.eval_algorithms.util import get_bert_score
+from fmeval.constants import BERTSCORE_DEFAULT_MODEL
+from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModel, BertscoreHelperModelTypes
 
 logger = logging.getLogger(__name__)
 
@@ -58,6 +63,7 @@
 }
 
 WER_SCORE = "word_error_rate"
+BERT_SCORE_DISSIMILARITY = "bertscore_dissimilarity"
 
 
 @dataclass(frozen=True)
@@ -67,6 +73,9 @@ class GeneralSemanticRobustnessConfig(EvalAlgorithmConfig):
 
     :param perturbation_type: perturbation type for generating perturbed inputs
     :param num_perturbations: Number of perturbed inputs to be generated for robustness evaluation
+    :param num_baseline_samples: Only used for non-deterministic models. Number of times we generate
+        the model output with the same input to compute the "baseline" change in model output. We
+        compute differences between all pairs of outputs, i.e. between comb(num_baseline_samples, 2) pairs.
    :param butter_finger_perturbation_prob: The probability that a given character will be perturbed. Used for
        butter_finger perturbation_type
    :param random_uppercase_corrupt_proportion: Fraction of characters to be changed to uppercase. Used for
@@ -75,34 +84,63 @@ class GeneralSemanticRobustnessConfig(EvalAlgorithmConfig):
         whitespace_add_remove perturbation_type
     :param whitespace_add_prob: Given a non-whitespace, add a whitespace before it with this probability. Used for
         whitespace_add_remove perturbation_type
+    :param model_type_for_bertscore: model to use for BERTScore
     """
 
     perturbation_type: str = BUTTER_FINGER
     num_perturbations: int = 5
+    num_baseline_samples: int = 4
     butter_finger_perturbation_prob: float = 0.1
     random_uppercase_corrupt_proportion: float = 0.1
     whitespace_remove_prob: float = 0.1
     whitespace_add_prob: float = 0.05
+    model_type_for_bertscore: str = BERTSCORE_DEFAULT_MODEL
 
     def __post_init__(self):
         if self.perturbation_type not in PERTURBATION_TYPE_TO_HELPER_CLASS.keys():
             raise EvalAlgorithmClientError(
                 f"Invalid perturbation type '{self.perturbation_type} requested, please "
                 f"choose from acceptable values: {PERTURBATION_TYPE_TO_HELPER_CLASS.keys()}"
             )
+        if not BertscoreHelperModelTypes.model_is_allowed(self.model_type_for_bertscore):
+            raise EvalAlgorithmClientError(
+                f"Invalid model_type_for_bertscore: {self.model_type_for_bertscore} requested in "
+                f"GeneralSemanticRobustnessConfig, please choose from acceptable values: {BertscoreHelperModelTypes.model_list()}."
+            )
+        if self.num_baseline_samples < 2:
+            raise EvalAlgorithmClientError(
+                f"Invalid num_baseline_samples: {self.num_baseline_samples} in GeneralSemanticRobustnessConfig. "
+                f"The value should be at least 2."
+            )
 
 
 class GeneralSemanticRobustness(EvalAlgorithmInterface):
     """
-    Semantic Robustness Eval algorithm for General task LLMs
+    Semantic Robustness Eval algorithm for General task LLMs.
 
     This evaluation measures how much the model output changes as a result of semantic preserving
     perturbations. Given the input, e.g., "A quick brown fox jumps over the lazy dog", the
     evaluation creates a perturbation that preserves the semantic meaning of the input e.g.,
     whitespace perturbation that changes the input text to "A q uick bro wn fox ju mps overthe lazy
     dog". The evaluation then measures how much the model output changes when prompted with the
-    original vs. perturbed input. The output difference is measured using Word Error Rate (WER).
-    https://huggingface.co/spaces/evaluate-metric/wer
+    original vs. perturbed input.
+
+    The output difference is measured using two metrics: the Word Error Rate
+    (https://huggingface.co/spaces/evaluate-metric/wer) and the BERTScore Dissimilarity, which is
+    1 - BERTScore (https://huggingface.co/spaces/evaluate-metric/bertscore), between the original
+    and the perturbed outputs. Word Error Rate measures syntactic differences, that is, changes in
+    the words, whereas BERTScore Dissimilarity measures semantic differences. Semantic differences
+    account for cases when the precise words in the output change but the meaning is the same, e.g.,
+    consider the outputs "it is pouring down today" vs. "it is very rainy today".
+
+    Note: When the model generation strategy is non-deterministic (e.g., with non-zero temperature),
+    the output can change even if the input is the same. In such scenarios, reporting differences
+    (using Word Error Rate or BERTScore Dissimilarity) between the model output on the original input
+    and perturbed inputs might show artificially low robustness since the model output changes even
+    without a change in the input. So this evaluation normalizes the robustness score to account for
+    the baseline non-determinism. Specifically, if d is a score (Word Error Rate or BERTScore
+    Dissimilarity), then the evaluation reports max(0, d - d_base) where d_base measures the
+    differences between the model output on the same input.
     """
 
     def __init__(self, eval_algorithm_config: GeneralSemanticRobustnessConfig = GeneralSemanticRobustnessConfig()):
@@ -126,6 +164,48 @@ def __init__(self, eval_algorithm_config: GeneralSemanticRobustnessConfig = GeneralSemanticRobustnessConfig()):
                 self._eval_algorithm_config.whitespace_remove_prob, self._eval_algorithm_config.whitespace_add_prob
             )
 
+        self._bertscore_helper_model = BertscoreHelperModel.remote(
+            model_type=self._eval_algorithm_config.model_type_for_bertscore
+        )
+
+    def _compute_baseline_scores(
+        self, model: ModelRunner, original_prompt: str, original_model_output
+    ) -> Dict[str, float]:
+        """
+        Private method for computing baseline scores. The baseline scores are required when the model
+        output is non-deterministic and measure the change in the model output with the same input.
+        See the class documentation for how the baseline scores are computed and used.
+
+        :param model: An instance of ModelRunner which is the model under evaluation
+        :param original_prompt: The input prompt to the model. Assumes that the input is already
+            embedded into the prompt template.
+        :param original_model_output: The output of the model on the original input prompt.
+
+        :return: A dict containing the score name to baseline score value mapping.
+        """
+        model_outputs = [
+            model.predict(original_prompt)[0] for _ in range(self._eval_algorithm_config.num_baseline_samples - 1)
+        ]
+        model_outputs.append(original_model_output)
+        all_pairs = itertools.combinations(model_outputs, 2)
+        first_output, second_output = zip(*all_pairs)
+        baselines = dict()
+
+        baselines[BERT_SCORE_DISSIMILARITY] = 1 - np.mean(
+            list(
+                map(
+                    functools.partial(get_bert_score, helper_model=self._bertscore_helper_model),
+                    first_output,
+                    second_output,
+                )
+            )
+        )
+
+        wer = hf_evaluate.load("wer")
+        baselines[WER_SCORE] = wer.compute(predictions=first_output, references=second_output)
+
+        return baselines
+
     def evaluate_sample(
         self,
         model_input: str,
@@ -151,9 +231,9 @@ def evaluate_sample(
         original_prompt = prompt_composer.compose(model_input)
         original_model_output = model_output if model_output else model.predict(original_prompt)[0]
 
-        if self._is_model_deterministic is None:
-            if model.predict(original_prompt)[0] != original_model_output:
-                raise EvalAlgorithmClientError("For evaluating semantic robustness, the model should be deterministic.")
+        is_model_deterministic = self._is_model_deterministic
+        if is_model_deterministic is None:
+            is_model_deterministic = model.predict(original_prompt)[0] == original_model_output
 
         perturbation = PERTURBATION_TYPE_TO_HELPER_CLASS[self._eval_algorithm_config.perturbation_type]()
         perturbed_inputs = perturbation.perturb(
@@ -164,18 +244,31 @@
         perturbed_input_prompts = [prompt_composer.compose(perturbed_input) for perturbed_input in perturbed_inputs]
         perturbed_input_outputs = [model.predict(prompt)[0] for prompt in perturbed_input_prompts]
 
+        bert_score_dissimilarity_value = 1 - np.mean(
+            list(
+                map(
+                    functools.partial(get_bert_score, helper_model=self._bertscore_helper_model),
+                    itertools.repeat(original_model_output, len(perturbed_input_outputs)),
+                    perturbed_input_outputs,
+                )
+            )
+        )
         wer = hf_evaluate.load("wer")
+        wer_value = wer.compute(
+            predictions=perturbed_input_outputs,
+            references=list(itertools.repeat(original_model_output, self._eval_algorithm_config.num_perturbations)),
+        )
 
-        return [
-            EvalScore(
-                name=WER_SCORE,
-                value=wer.compute(
-                    predictions=perturbed_input_outputs,
-                    references=list(
-                        itertools.repeat(original_model_output, self._eval_algorithm_config.num_perturbations)
-                    ),
-                ),
+        if not is_model_deterministic:  # Compute the baseline differences in the model outputs for the same input
+            baselines = self._compute_baseline_scores(model, original_prompt, original_model_output)
+            bert_score_dissimilarity_value = max(
+                0, bert_score_dissimilarity_value - baselines[BERT_SCORE_DISSIMILARITY]
             )
+            wer_value = max(0, wer_value - baselines[WER_SCORE])
+
+        return [
+            EvalScore(name=BERT_SCORE_DISSIMILARITY, value=bert_score_dissimilarity_value),
+            EvalScore(name=WER_SCORE, value=wer_value),
         ]
 
     def evaluate(
@@ -223,33 +316,18 @@ def evaluate(
             )
 
             self._is_model_deterministic = verify_model_determinism(model, dataset, DatasetColumns.PROMPT.value.name)
-            if not self._is_model_deterministic:
-                raise EvalAlgorithmClientError("For evaluating semantic robustness, the model should be deterministic.")
             dataset = generate_model_predict_response_for_dataset(
                 model=model,
                 data=dataset,
                 model_input_column_name=DatasetColumns.PROMPT.value.name,
                 model_output_column_name=DatasetColumns.MODEL_OUTPUT.value.name,
             )
-            with timed_block(f"Computing score and aggregation on dataset {dataset_config.dataset_name}", logger):
-
-                def _generate_general_semantic_robustness_score(
-                    row: Dict[str, Any]
-                ) -> Dict[str, Any]:  # pragma: no cover
-                    """
-                    Map function generating the scores for every input record in input dataset
-                    """
-                    row[WER_SCORE] = self.evaluate_sample(
-                        model_input=row[DatasetColumns.MODEL_INPUT.value.name],
-                        model=model,
-                        model_output=row[DatasetColumns.MODEL_OUTPUT.value.name],
-                        prompt_template=dataset_prompt_template,
-                    )[0].value
-                    return row
+            with (timed_block(f"Computing score and aggregation on dataset {dataset_config.dataset_name}", logger)):
 
-                dataset = dataset.map(_generate_general_semantic_robustness_score).materialize()
-
-                dataset_scores, category_scores = aggregate_evaluation_scores(dataset, [WER_SCORE], agg_method=MEAN)
+                dataset = self.__add_scores_to_dataset(dataset, model, dataset_prompt_template)
+                dataset_scores, category_scores = aggregate_evaluation_scores(
+                    dataset, [BERT_SCORE_DISSIMILARITY, WER_SCORE], agg_method=MEAN
+                )
                 eval_outputs.append(
                     EvalOutput(
                         eval_name=self.eval_name,
@@ -268,7 +346,7 @@ def _generate_general_semantic_robustness_score(
                 if save:
                     save_dataset(
                         dataset=dataset,
-                        score_names=[WER_SCORE],
+                        score_names=[BERT_SCORE_DISSIMILARITY, WER_SCORE],
                         path=generate_output_dataset_path(
                             path_to_parent_dir=self._eval_results_path,
                             eval_name=self.eval_name,
@@ -277,3 +355,30 @@ def _generate_general_semantic_robustness_score(
                     )
 
         return eval_outputs
+
+    def __add_scores_to_dataset(self, dataset: Dataset, model: ModelRunner, prompt_template: str):
+        """
+        Private method to encapsulate logic around getting scores for every row in the dataset.
+
+        :param dataset: ray Dataset to be used for eval scores generation
+        :param model: An instance of ModelRunner which is the model under evaluation
+        :param prompt_template: Prompt template used to compose prompts for the model under evaluation
+        :returns: ray Dataset with score columns
+        """
+
+        def _generate_general_semantic_robustness_score(row: Dict[str, Any]) -> Dict[str, Any]:  # pragma: no cover
+            """
+            Map function generating the scores for every input record in input dataset
+            """
+            scores = self.evaluate_sample(
+                model_input=row[DatasetColumns.MODEL_INPUT.value.name],
+                model=model,
+                model_output=row[DatasetColumns.MODEL_OUTPUT.value.name],
+                prompt_template=prompt_template,
+            )
+            row[BERT_SCORE_DISSIMILARITY] = scores[0].value
+            row[WER_SCORE] = scores[1].value
+
+            return row
+
+        return dataset.map(_generate_general_semantic_robustness_score).materialize()  # pragma: no cover
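The class docstring above reports max(0, d - d_base) for each score, where d_base is averaged over all comb(num_baseline_samples, 2) pairs of outputs produced from the same prompt. Below is a minimal, self-contained sketch of that normalization; it is not fmeval code: word_jaccard_distance is a toy stand-in for either real metric (Word Error Rate or BERTScore Dissimilarity), and the model outputs are made up.

import itertools
import statistics


def word_jaccard_distance(a: str, b: str) -> float:
    """Toy dissimilarity: 1 minus the Jaccard similarity of the word sets."""
    wa, wb = set(a.split()), set(b.split())
    return 1.0 - len(wa & wb) / len(wa | wb)


# Hypothetical outputs for the *same* original prompt (num_baseline_samples = 3 here)
# and for the perturbed prompts.
baseline_outputs = [
    "it is very rainy today",
    "it is pouring down today",
    "today it is very rainy",
]
perturbed_outputs = [
    "it is quite rainy today",
    "rain is expected all day today",
]
original_output = baseline_outputs[0]

# d: mean dissimilarity between the original output and each perturbed output.
d = statistics.mean(word_jaccard_distance(original_output, p) for p in perturbed_outputs)

# d_base: mean dissimilarity over all comb(3, 2) pairs of same-prompt outputs,
# capturing the model's own output noise.
d_base = statistics.mean(
    word_jaccard_distance(x, y) for x, y in itertools.combinations(baseline_outputs, 2)
)

# Reported score: a non-deterministic model is not penalized for changes it
# would have made even without a perturbation.
print(max(0.0, d - d_base))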

src/fmeval/eval_algorithms/helper_models/helper_model.py

Lines changed: 25 additions & 0 deletions
@@ -1,3 +1,4 @@
+from enum import Enum
 import ray
 import numpy as np
 import evaluate as hf_evaluate
@@ -219,3 +220,27 @@ def get_helper_scores(self, target_output: str, model_output: str) -> float:  #
             references=[target_output],
             model_type=self._model_type,
         )["f1"][0]
+
+
+class BertscoreHelperModelTypes(Enum):
+    """This class holds the names of all the allowed models for computing the BERTScore."""
+
+    MICROSOFT_DEBERTA_MODEL = "microsoft/deberta-xlarge-mnli"
+    ROBERTA_MODEL = "roberta-large-mnli"
+
+    @classmethod
+    def model_is_allowed(cls, model_name: str) -> bool:
+        """
+        Given a model name like 'roberta-large-mnli', check if this is an allowed model for computing BERTScore.
+        """
+        for elem in iter(cls):
+            if elem.value == model_name:
+                return True
+        return False
+
+    @classmethod
+    def model_list(cls) -> List[str]:
+        """
+        Return a list of all the allowed models for computing BERTScore.
+        """
+        return [elem.value for elem in iter(cls)]
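A quick usage sketch for the validation helpers added above (it assumes only the BertscoreHelperModelTypes enum as defined in this diff; the printed values follow directly from its two members):

from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModelTypes

# List every model name the BERTScore helper accepts.
print(BertscoreHelperModelTypes.model_list())
# ['microsoft/deberta-xlarge-mnli', 'roberta-large-mnli']

# This is the check GeneralSemanticRobustnessConfig.__post_init__ runs on model_type_for_bertscore.
print(BertscoreHelperModelTypes.model_is_allowed("roberta-large-mnli"))  # True
print(BertscoreHelperModelTypes.model_is_allowed("bert-base-uncased"))   # False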

0 commit comments
