diff --git a/delphi/__main__.py b/delphi/__main__.py index d69d7b10..f15493e8 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -22,7 +22,9 @@ from delphi.clients import Offline, OpenRouter from delphi.config import RunConfig from delphi.explainers import ContrastiveExplainer, DefaultExplainer, NoOpExplainer +from delphi.explainers.bestofk import BestOfKOrchestrator from delphi.explainers.explainer import ExplainerResult +from delphi.explainers.iterative import HillClimbingOrchestrator from delphi.latents import LatentCache, LatentDataset from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results @@ -171,13 +173,90 @@ async def process_cache( f"Explainer provider {run_cfg.explainer_provider} not supported" ) - if not run_cfg.explainer == "none": + # Define postprocessors used by all explainer modes + def explainer_postprocess(result, is_final=True): + with open(explanations_path / f"{result.record.latent}.txt", "wb") as f: + f.write(orjson.dumps(result.explanation)) + return result - def explainer_postprocess(result): - with open(explanations_path / f"{result.record.latent}.txt", "wb") as f: - f.write(orjson.dumps(result.explanation)) + def scorer_postprocess_fn(result, score_dir, round_idx=None, is_final=False): + safe_latent = str(result.record.latent).replace("/", "--") + if round_idx is not None and not is_final: + filename = f"{safe_latent}_round{round_idx}.txt" + else: + filename = f"{safe_latent}.txt" + with open(score_dir / filename, "wb") as f: + f.write(orjson.dumps(result.score)) + + # Handle BestOfK and Iterative orchestrators + if run_cfg.explainer in ("bestofk", "iterative"): + # Build scorers for orchestrators + scorers_with_paths = [] + for scorer_name in run_cfg.scorers: + scorer_path = scores_path / scorer_name + scorer_path.mkdir(parents=True, exist_ok=True) + + if scorer_name == "simulation": + scorer = OpenAISimulator( + llm_client, + tokenizer=tokenizer, + all_at_once=isinstance(llm_client, Offline), + ) + elif scorer_name == "fuzz": + scorer = FuzzingScorer( + llm_client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + fuzz_type=run_cfg.fuzz_type, + ) + elif scorer_name == "detection": + scorer = DetectionScorer( + llm_client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) + else: + raise ValueError(f"Scorer {scorer_name} not supported") + + scorers_with_paths.append((scorer, scorer_path)) + + if run_cfg.explainer == "bestofk": + orchestrator = BestOfKOrchestrator( + client=llm_client, + scorers_with_paths=scorers_with_paths, + num_explanations=run_cfg.bestofk_num_explanations, + judge_scorer_index=run_cfg.bestofk_judge_scorer_index, + num_train_examples=run_cfg.bestofk_num_train_examples, + verbose=run_cfg.verbose, + scorer_postprocess=scorer_postprocess_fn, + ) + else: # iterative + orchestrator = HillClimbingOrchestrator( + client=llm_client, + scorers_with_paths=scorers_with_paths, + num_rounds=run_cfg.iterative_num_rounds, + judge_scorer_index=run_cfg.bestofk_judge_scorer_index, + max_false_positives=run_cfg.iterative_max_false_positives, + max_false_negatives=run_cfg.iterative_max_false_negatives, + carryforward_strategy=run_cfg.iterative_carryforward_strategy, + verbose=run_cfg.verbose, + scorer_postprocess=scorer_postprocess_fn, + explainer_postprocess=explainer_postprocess, + ) + + # Run orchestrator on each record + for record in dataset: + try: + result = 
await orchestrator(record) + explainer_postprocess(result) + except Exception as e: + logger.error(f"Orchestrator failed for {record.latent}: {repr(e)}") + + return # Skip normal pipeline for orchestrators - return result + elif not run_cfg.explainer == "none": if run_cfg.constructor_cfg.non_activating_source == "FAISS": explainer = ContrastiveExplainer( diff --git a/delphi/config.py b/delphi/config.py index de806157..a89889f5 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -141,11 +141,12 @@ class RunConfig(Serializable): models and 'openrouter' for API calls.""" explainer: str = field( - choices=["default", "none"], + choices=["default", "none", "bestofk", "iterative"], default="default", ) """Explainer to use for generating explanations. Options are 'default' for - the default single token explainer, and 'none' for no explanation generation.""" + the default single token explainer, 'none' for no explanation generation, + 'bestofk' for best-of-K sampling, and 'iterative' for iterative refinement.""" scorers: list[str] = list_field( choices=[ @@ -165,6 +166,30 @@ class RunConfig(Serializable): examples and highlights n_incorrect tokens. Active uses activating examples and highlights non-activating tokens.""" + # BestOfK explainer config + bestofk_num_explanations: int = field(default=5) + """Number of explanation candidates to generate for best-of-K selection.""" + + bestofk_judge_scorer_index: int = field(default=0) + """Index of the scorer to use for selecting the best explanation.""" + + bestofk_num_train_examples: int | None = field(default=20) + """Number of training examples to show per explanation. None uses all available.""" + + # Iterative explainer config + iterative_num_rounds: int = field(default=3) + """Number of refinement rounds for iterative explanation.""" + + iterative_max_false_positives: int = field(default=20) + """Maximum false positive examples to include in refinement prompts.""" + + iterative_max_false_negatives: int = field(default=20) + """Maximum false negative examples to include in refinement prompts.""" + + iterative_carryforward_strategy: Literal["best", "last"] = "last" + """Strategy for selecting final explanation: 'best' uses highest-scoring, + 'last' uses most recent refinement.""" + name: str = "" """The name of the run. Results are saved in a directory with this name.""" diff --git a/delphi/explainers/__init__.py b/delphi/explainers/__init__.py index 8cbc5579..1b85919a 100644 --- a/delphi/explainers/__init__.py +++ b/delphi/explainers/__init__.py @@ -1,6 +1,8 @@ +from .bestofk import BestOfKOrchestrator from .contrastive_explainer import ContrastiveExplainer from .default.default import DefaultExplainer from .explainer import Explainer, explanation_loader, random_explanation_loader +from .iterative import HillClimbingOrchestrator from .no_op_explainer import NoOpExplainer from .single_token_explainer import SingleTokenExplainer @@ -12,4 +14,6 @@ "random_explanation_loader", "ContrastiveExplainer", "NoOpExplainer", + "BestOfKOrchestrator", + "HillClimbingOrchestrator", ] diff --git a/delphi/explainers/bestofk.py b/delphi/explainers/bestofk.py new file mode 100644 index 00000000..b1dd208e --- /dev/null +++ b/delphi/explainers/bestofk.py @@ -0,0 +1,376 @@ +""" +Best-of-K Explainer Orchestrator + +Generates multiple explanation candidates and selects the best one based on scorer feedback. +This is a standalone orchestrator that manages its own pipeline internally. 
+""" + +import asyncio +import random +import re +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from statistics import fmean +from typing import Callable, Optional + +from delphi import logger +from delphi.explainers.explainer import ExplainerResult +from delphi.latents.latents import ActivatingExample, LatentRecord +from delphi.pipeline import Pipe, Pipeline, process_wrapper +from delphi.scorers.scorer import Scorer, ScorerResult + +from .default.prompt_builder import build_prompt + +# System prompt for generating multiple explanations in one shot +SYSTEM_BESTOFK_ONESHOT = """You are a meticulous AI researcher conducting an important investigation into patterns found in language. Your task is to analyze text and provide an explanation that thoroughly encapsulates possible patterns found in it. +Guidelines: + +You will be given a list of text examples on which special words are selected and between delimiters like <>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <>. How important each token is for the behavior is listed after each example in parentheses. + +- Try to produce a concise final description. Simply describe the text latents that are common in the examples, and what patterns you found. +- If the examples are uninformative, you don't need to mention them. Don't focus on giving examples of important tokens, but try to summarize the patterns found in the examples. +- Do not mention the marker tokens (<< >>) in your explanation. +- Do not make lists of possible explanations. Keep your explanations short and concise. +- You will be given a number telling you how many explanations you are to generate - these explanations should be meaningfully distinct attempts to encapsulate the pattern. They should not be too similar. +- The final part of your response must consist exclusively of formatted explanations, each explanation on a new line, starting with "[EXPLANATION]:" followed by the explanation. It is imperative that you follow this format as the text is to be processed programmatically. + +""" + + +@dataclass +class BestOfKOrchestrator: + """ + Generates K explanation candidates and selects the best based on scorer feedback. + + Unlike regular explainers, this orchestrator runs scorers internally to evaluate + candidates before returning the final result. + """ + + client: object + """LLM client for generating explanations.""" + + scorers_with_paths: list[tuple[Scorer, Path]] = field(default_factory=list) + """List of (scorer, output_path) tuples for evaluation.""" + + num_explanations: int = 5 + """Number of explanation candidates to generate.""" + + judge_scorer_index: int = 0 + """Index of scorer to use for selecting best explanation.""" + + num_train_examples: Optional[int] = 20 + """Number of training examples to show. None uses all available.""" + + temperature: float = 0.0 + """Sampling temperature for generation.""" + + threshold: float = 0.3 + """Activation threshold for highlighting tokens.""" + + activations: bool = True + """Whether to show activation values in prompts.""" + + is_multishot: bool = True + """If True, make K separate calls. 
If False, request K explanations in one call.""" + + verbose: bool = False + """Whether to log verbose output.""" + + scorer_preprocess: Optional[Callable] = None + """Preprocessing function applied before scoring.""" + + scorer_postprocess: Optional[Callable] = None + """Postprocessing function applied after scoring.""" + + async def __call__(self, record: LatentRecord) -> ExplainerResult: + """Generate K explanations, score them, and return the best.""" + + # Split into train/test pools + train_pool, test_activating, test_non_activating = self._split_train_test( + record + ) + + # Create clean record for scoring + clean_record = LatentRecord( + latent=record.latent, + train=train_pool, + test=test_activating, + not_active=test_non_activating, + explanation=record.explanation, + ) + + # Build prompt and generate explanations + messages = self._build_prompt(clean_record.train) + + if self.is_multishot: + # Make K separate calls + tasks = [ + self.client.generate(messages, temperature=self.temperature) + for _ in range(self.num_explanations) + ] + responses = await asyncio.gather(*tasks) + combined_text = "\n".join([r.text for r in responses]) + explanations = self._parse_multiple_explanations(combined_text) + else: + # Single call requesting K explanations + oneshot_messages = self._build_oneshot_prompt(clean_record.train) + response = await self.client.generate( + oneshot_messages, temperature=self.temperature + ) + explanations = self._parse_multiple_explanations(response.text) + + # Cap at requested number + explanations = explanations[: self.num_explanations] + + # Create ExplainerResult for each candidate + explainer_results = [] + for idx, explanation in enumerate(explanations): + result_record = LatentRecord( + latent=clean_record.latent, + train=clean_record.train, + test=clean_record.test, + not_active=clean_record.not_active, + explanation=explanation, + ) + explainer_results.append( + ExplainerResult(record=result_record, explanation=explanation) + ) + + if not explainer_results: + # Fallback if parsing failed + return ExplainerResult( + record=clean_record, explanation="Explanation could not be parsed." 
+ ) + + # Score all candidates + scorer_results = await self._run_scorers(explainer_results) + + # Select best based on judge scorer + judge_results = [ + s[self.judge_scorer_index] + for s in scorer_results + if s[self.judge_scorer_index] + ] + best_idx = self._select_best_idx(judge_results) + + # Save best scores + for scorer_idx, (scorer, score_dir) in enumerate(self.scorers_with_paths): + if best_idx < len(scorer_results) and scorer_results[best_idx][scorer_idx]: + best_score = scorer_results[best_idx][scorer_idx] + if self.scorer_postprocess: + self.scorer_postprocess(best_score, score_dir=score_dir) + + return explainer_results[best_idx] + + def _split_train_test(self, record: LatentRecord) -> tuple[list, list, list]: + """Split record into train pool, test activating, and test non-activating.""" + return ( + list(record.train), + list(record.test), + list(record.not_active), + ) + + def _build_prompt(self, examples: list[ActivatingExample]) -> list[dict]: + """Build prompt from examples using upstream DefaultExplainer logic.""" + # Sample if needed + if self.num_train_examples and len(examples) > self.num_train_examples: + examples = random.sample(examples, self.num_train_examples) + + # Highlight examples + highlighted = [] + for example in examples: + str_toks = example.str_tokens + activations_list = example.activations.tolist() + highlighted.append(self._highlight(str_toks, activations_list)) + + if self.activations and example.normalized_activations is not None: + normalized = example.normalized_activations.tolist() + highlighted.append( + self._join_activations(str_toks, activations_list, normalized) + ) + + highlighted_str = "\n".join(highlighted) + return build_prompt(examples=highlighted_str, activations=self.activations) + + def _build_oneshot_prompt(self, examples: list[ActivatingExample]) -> list[dict]: + """Build prompt for one-shot multi-explanation generation.""" + if self.num_train_examples and len(examples) > self.num_train_examples: + examples = random.sample(examples, self.num_train_examples) + + highlighted = [] + for example in examples: + str_toks = example.str_tokens + activations_list = example.activations.tolist() + highlighted.append(self._highlight(str_toks, activations_list)) + + if self.activations and example.normalized_activations is not None: + normalized = example.normalized_activations.tolist() + highlighted.append( + self._join_activations(str_toks, activations_list, normalized) + ) + + highlighted_str = "\n".join(highlighted) + + messages = [{"role": "system", "content": SYSTEM_BESTOFK_ONESHOT}] + messages.append({"role": "user", "content": f"\n{highlighted_str}\n"}) + messages.append( + { + "role": "user", + "content": f"The number of explanations to generate is: {self.num_explanations}.", + } + ) + + return messages + + def _highlight(self, str_toks: list[str], activations: list[float]) -> str: + """Highlight tokens above threshold with << >> markers.""" + result = "" + threshold_val = max(activations) * self.threshold if activations else 0 + + i = 0 + while i < len(str_toks): + if activations[i] > threshold_val: + result += "<<" + while i < len(str_toks) and activations[i] > threshold_val: + result += str_toks[i] + i += 1 + result += ">>" + else: + result += str_toks[i] + i += 1 + + return result + + def _join_activations( + self, + str_toks: list[str], + token_activations: list[float], + normalized_activations: list[float], + ) -> str: + """Format activation values for display.""" + acts = "" + count = 0 + threshold_val = ( + 
max(token_activations) * self.threshold if token_activations else 0 + ) + + for str_tok, tok_act, norm_act in zip( + str_toks, token_activations, normalized_activations + ): + if tok_act > threshold_val: + if count > 10: + break + acts += f'("{str_tok}" : {int(norm_act)}), ' + count += 1 + + return "Activations: " + acts + + def _parse_multiple_explanations(self, text: str) -> list[str]: + """Parse multiple [EXPLANATION]: markers from text.""" + try: + matches = re.findall( + r"\[EXPLANATION\]:\s*(.*?)(?=\[EXPLANATION\]:|$)", text, re.DOTALL + ) + if matches: + cleaned = [m.strip() for m in matches if m.strip()] + return cleaned if cleaned else ["Explanation could not be parsed."] + return ["Explanation could not be parsed."] + except Exception as e: + logger.error(f"Explanation parsing failed: {repr(e)}") + return ["Explanation could not be parsed."] + + async def _run_scorers( + self, explainer_results: list[ExplainerResult] + ) -> list[list[Optional[ScorerResult]]]: + """Run all scorers on all explanation candidates.""" + num_scorers = len(self.scorers_with_paths) + all_results: list[list[Optional[ScorerResult]]] = [ + [None] * num_scorers for _ in range(len(explainer_results)) + ] + + def make_wrapper(scorer_idx: int): + scorer, score_dir = self.scorers_with_paths[scorer_idx] + return process_wrapper( + scorer, + preprocess=self.scorer_preprocess, + postprocess=( + partial( + self.scorer_postprocess or (lambda r, **_: r), + score_dir=score_dir, + ) + if self.scorer_postprocess + else None + ), + ) + + wrappers = [make_wrapper(idx) for idx in range(num_scorers)] + + async def generator(): + for result in explainer_results: + yield result + + pipeline = Pipeline(generator(), Pipe(*wrappers)) + subset_results = await pipeline.run() + + for pos, scorer_list in enumerate(subset_results): + for scorer_idx, scorer_result in enumerate(scorer_list): + all_results[pos][scorer_idx] = scorer_result + + return all_results + + def _select_best_idx(self, scorer_results: list[ScorerResult]) -> int: + """Select index of best explanation based on scorer results.""" + if not scorer_results: + return 0 + + scores = [] + for result in scorer_results: + if result is None: + scores.append(float("-inf")) + continue + scores.append(self._compute_score(result)) + + return max(range(len(scores)), key=lambda i: scores[i]) + + def _compute_score(self, result: ScorerResult) -> float: + """Compute score from scorer result (F1 for classifier, similarity for embedding).""" + samples = result.score or [] + if not samples: + return float("-inf") + + # Check if embedding scorer (has 'similarity' attribute) + if hasattr(samples[0], "similarity"): + return self._compute_embedding_score(samples) + + # Otherwise assume classifier output + return self._compute_f1_score(samples) + + def _compute_f1_score(self, samples) -> float: + """Compute F1 score from classifier outputs.""" + tp = fp = fn = 0 + for sample in samples: + if sample.correct: + if sample.activating: + tp += 1 + else: + if sample.activating: + fn += 1 + else: + fp += 1 + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + return ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + + def _compute_embedding_score(self, samples) -> float: + """Compute embedding score as difference of positive/negative similarities.""" + pos = [s.similarity for s in samples if s.activating] + neg = [s.similarity for s in samples if not s.activating] + if not pos or not neg: + return 
float("-inf") + return fmean(pos) - fmean(neg) diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index 3c3a488d..bf3a725b 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -4,7 +4,7 @@ import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import NamedTuple +from typing import NamedTuple, Optional import aiofiles @@ -21,6 +21,9 @@ class ExplainerResult(NamedTuple): explanation: str """Generated explanation for latent.""" + duration: Optional[float] = None + """Time taken to generate the explanation in seconds.""" + @dataclass class Explainer(ABC): diff --git a/delphi/explainers/iterative/__init__.py b/delphi/explainers/iterative/__init__.py new file mode 100644 index 00000000..88ac3167 --- /dev/null +++ b/delphi/explainers/iterative/__init__.py @@ -0,0 +1,5 @@ +"""Iterative explanation refinement module.""" + +from .hill_climbing import HillClimbingOrchestrator + +__all__ = ["HillClimbingOrchestrator"] diff --git a/delphi/explainers/iterative/hill_climbing.py b/delphi/explainers/iterative/hill_climbing.py new file mode 100644 index 00000000..df018e89 --- /dev/null +++ b/delphi/explainers/iterative/hill_climbing.py @@ -0,0 +1,437 @@ +""" +Iterative Hill-Climbing Explainer Orchestrator + +Generates explanations iteratively, using scorer feedback (false positives/negatives) +to refine explanations over multiple rounds. +""" + +import random +import re +from dataclasses import dataclass, field +from pathlib import Path +from statistics import fmean +from typing import Callable, Literal, Optional + +import torch + +from delphi import logger +from delphi.explainers.explainer import ExplainerResult +from delphi.latents import ( + ActivatingExample, + Example, + LatentRecord, + NonActivatingExample, +) +from delphi.scorers.scorer import Scorer, ScorerResult + +# System prompt for iterative refinement +SYSTEM_ITERATIVE = """You are a meticulous AI researcher conducting an important investigation into patterns found in language. Your task is to analyze text and provide an explanation that thoroughly encapsulates possible patterns found in it. +Guidelines: + +You will be given a list of text examples on which special words are selected and between delimiters like <>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <>. How important each token is for the behavior is listed after each example in parentheses. +Your task is to provide a necessary and sufficient explanation that predicts when the pattern is present (i.e., what condition causes the tokens to be marked) + +- Try to produce a concise final description. Simply describe the text latents that are common in the examples, and what patterns you found. +- If the examples are uninformative, you don't need to mention them. Don't focus on giving examples of important tokens, but try to summarize the patterns found in the examples. +- Do not mention the marker tokens (<< >>) in your explanation. +- Do not make lists of possible explanations. Keep your explanations short and concise. +- You may be given a previous attempted explanation of the pattern, along with some false-negative or false-positives. Please use these to refine the explanation - do NOT return the same explanation, instead refine it based on the new data. +- If iterating on a given explanation, the examples will be labeled according to type (normal, false-negative, false-positive - e.g. 
a false-positive example is an example that was incorrectly identified as having the pattern based on the explanation shown. +- If you are not given a prior explanation, examples will not be labeled and are all normal examples known to activate the pattern. +- The last line of your response must be the explanation, beginning with "[EXPLANATION]:" followed by the explanation with no line breaks. Your answer will be processed programmatically so please comply with these rules. + +""" + + +@dataclass +class HillClimbingOrchestrator: + """ + Iteratively refines explanations using scorer feedback over multiple rounds. + + Each round: + 1. Generate/refine explanation based on examples and previous errors + 2. Score the explanation + 3. Extract false positives/negatives for next round + 4. Repeat + + Final result selected by 'best' (highest score) or 'last' (most recent) strategy. + """ + + client: object + """LLM client for generating explanations.""" + + scorers_with_paths: list[tuple[Scorer, Path]] = field(default_factory=list) + """List of (scorer, output_path) tuples for evaluation.""" + + num_rounds: int = 3 + """Number of refinement rounds.""" + + judge_scorer_index: int = 0 + """Index of scorer to use for selecting best explanation.""" + + max_false_positives: int = 20 + """Maximum false positive examples to include in refinement prompts.""" + + max_false_negatives: int = 20 + """Maximum false negative examples to include in refinement prompts.""" + + carryforward_strategy: Literal["best", "last"] = "last" + """Strategy for final selection: 'best' or 'last'.""" + + num_train_examples_per_round: int = 20 + """Number of training examples to show each round.""" + + threshold: float = 0.3 + """Activation threshold for highlighting tokens.""" + + activations: bool = True + """Whether to show activation values in prompts.""" + + temperature: float = 0.0 + """Sampling temperature for generation.""" + + verbose: bool = False + """Whether to log verbose output.""" + + scorer_postprocess: Optional[Callable] = None + """Postprocessing function for scorer results.""" + + explainer_postprocess: Optional[Callable] = None + """Postprocessing function for explainer results.""" + + async def __call__(self, record: LatentRecord) -> ExplainerResult: + """Run iterative refinement and return best/last explanation.""" + + # Split data + ( + train_pool, + test_activating, + test_non_activating, + holdout_activating, + holdout_non_activating, + ) = self._split_train_test_holdout(record) + + explanations = [] + all_holdout_scores = [] + wrong_examples = [] + current_explanation = None + + for round_idx in range(self.num_rounds): + # Sample training examples + if len(train_pool) > self.num_train_examples_per_round: + sampled_train = random.sample( + train_pool, self.num_train_examples_per_round + ) + else: + sampled_train = train_pool + + # Build prompt + if current_explanation is None: + # Initial prompt + messages = self._build_initial_prompt(sampled_train) + else: + # Refinement prompt with FP/FN feedback + messages = self._build_refinement_prompt( + sampled_train, current_explanation, wrong_examples + ) + + # Generate explanation + response = await self.client.generate( + messages, temperature=self.temperature + ) + explanation_text = self._parse_explanation(response.text) + + # Create result + result_record = LatentRecord( + latent=record.latent, + train=sampled_train, + test=test_activating, + not_active=test_non_activating, + explanation=explanation_text, + ) + result = ExplainerResult(record=result_record, 
explanation=explanation_text) + explanations.append(result) + + if self.explainer_postprocess: + self.explainer_postprocess(result, is_final=False) + + # Score on test set + test_scorer_results = await self._run_scorers(result_record) + + # Score on holdout set + holdout_record = LatentRecord( + latent=record.latent, + train=sampled_train, + test=holdout_activating, + not_active=holdout_non_activating, + explanation=explanation_text, + ) + holdout_scorer_results = await self._run_scorers(holdout_record) + all_holdout_scores.append(holdout_scorer_results) + + # Save intermediate scores + for scorer_idx, (_, score_dir) in enumerate(self.scorers_with_paths): + if self.scorer_postprocess and holdout_scorer_results[scorer_idx]: + self.scorer_postprocess( + holdout_scorer_results[scorer_idx], + score_dir=score_dir, + round_idx=round_idx, + ) + + # Extract wrong examples for next round + wrong_examples = self._extract_wrong_examples(test_scorer_results) + + # Update current explanation for next round + current_explanation = explanation_text + + # Select final explanation + if self.carryforward_strategy == "best": + best_idx = self._select_best_idx(all_holdout_scores) + final_result = explanations[best_idx] + final_scores = all_holdout_scores[best_idx] + else: + final_result = explanations[-1] + final_scores = all_holdout_scores[-1] + + # Save final scores + for scorer_idx, (_, score_dir) in enumerate(self.scorers_with_paths): + if self.scorer_postprocess and final_scores[scorer_idx]: + self.scorer_postprocess( + final_scores[scorer_idx], score_dir=score_dir, is_final=True + ) + + if self.explainer_postprocess: + self.explainer_postprocess(final_result, is_final=True) + + return final_result + + def _split_train_test_holdout(self, record: LatentRecord): + """Split record into train pool, test set, and holdout set.""" + train_pool = list(record.train) + test_activating = list(record.train) # Use train for FP/FN collection + test_non_activating = list(record.not_active) + holdout_activating = list(record.test) # Use test as holdout + holdout_non_activating = list(record.not_active) + + return ( + train_pool, + test_activating, + test_non_activating, + holdout_activating, + holdout_non_activating, + ) + + def _build_initial_prompt(self, examples: list[ActivatingExample]) -> list[dict]: + """Build initial prompt without prior explanation.""" + highlighted = self._format_examples(examples) + messages = [{"role": "system", "content": SYSTEM_ITERATIVE}] + messages.append({"role": "user", "content": f"\n{highlighted}\n"}) + return messages + + def _build_refinement_prompt( + self, + examples: list[ActivatingExample], + current_explanation: str, + wrong_examples: list[Example], + ) -> list[dict]: + """Build refinement prompt with FP/FN feedback.""" + # Format normal examples + highlighted = self._format_examples(examples) + + # Separate FP/FN + false_positives = [] + false_negatives = [] + for ex in wrong_examples: + if ex.activations.max() > 0: + false_negatives.append(ex) + else: + false_positives.append(ex) + + fp_str = self._format_examples( + false_positives[: self.max_false_positives], show_activations=False + ) + fn_str = self._format_examples(false_negatives[: self.max_false_negatives]) + + messages = [{"role": "system", "content": SYSTEM_ITERATIVE}] + messages.append( + {"role": "user", "content": f"Normal examples:\n{highlighted}\n"} + ) + messages.append( + { + "role": "user", + "content": ( + f"Current explanation: {current_explanation}\n\n" + f"False negatives:\n{fn_str}\n" + f"False 
positives:\n{fp_str}\n" + ), + } + ) + + return messages + + def _format_examples( + self, examples: list[Example], show_activations: bool = True + ) -> str: + """Format examples with highlighting.""" + parts = [] + for i, example in enumerate(examples): + str_toks = example.str_tokens + acts = example.activations.tolist() + highlighted = self._highlight(str_toks, acts) + parts.append(f"Example {i}: {highlighted}") + + if show_activations and self.activations: + if ( + hasattr(example, "normalized_activations") + and example.normalized_activations is not None + ): + norm_acts = example.normalized_activations.tolist() + parts.append(self._join_activations(str_toks, acts, norm_acts)) + + return "\n".join(parts) + + def _highlight(self, str_toks: list[str], activations: list[float]) -> str: + """Highlight tokens above threshold.""" + result = "" + threshold_val = max(activations) * self.threshold if activations else 0 + + i = 0 + while i < len(str_toks): + if activations[i] > threshold_val: + result += "<<" + while i < len(str_toks) and activations[i] > threshold_val: + result += str_toks[i] + i += 1 + result += ">>" + else: + result += str_toks[i] + i += 1 + + return result + + def _join_activations( + self, + str_toks: list[str], + token_activations: list[float], + normalized_activations: list[float], + ) -> str: + """Format activation values.""" + acts = "" + count = 0 + threshold_val = ( + max(token_activations) * self.threshold if token_activations else 0 + ) + + for str_tok, tok_act, norm_act in zip( + str_toks, token_activations, normalized_activations + ): + if tok_act > threshold_val: + if count > 10: + break + acts += f'("{str_tok}" : {int(norm_act)}), ' + count += 1 + + return "Activations: " + acts + + def _parse_explanation(self, text: str) -> str: + """Parse [EXPLANATION]: from response.""" + try: + match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) + if match: + return match.group(1).strip() + return "Explanation could not be parsed." + except Exception as e: + logger.error(f"Explanation parsing failed: {repr(e)}") + return "Explanation could not be parsed." 
+ + async def _run_scorers(self, record: LatentRecord) -> list[Optional[ScorerResult]]: + """Run all scorers on record.""" + results = [] + for scorer, _ in self.scorers_with_paths: + try: + result = await scorer(record) + results.append(result) + except Exception as e: + logger.error(f"Scorer failed: {repr(e)}") + results.append(None) + return results + + def _extract_wrong_examples( + self, scorer_results: list[Optional[ScorerResult]] + ) -> list[Example]: + """Extract incorrectly classified examples from scorer results.""" + wrong = [] + for result in scorer_results: + if result is None or not result.score: + continue + + for sample in result.score: + if not hasattr(sample, "correct") or sample.correct: + continue + + # Create example from wrong prediction + if sample.activating: + ex = ActivatingExample( + tokens=torch.tensor(0), + activations=torch.tensor(sample.activations), + str_tokens=sample.str_tokens, + normalized_activations=torch.tensor(sample.activations), + ) + else: + ex = NonActivatingExample( + tokens=torch.tensor(0), + activations=torch.tensor(sample.activations), + str_tokens=sample.str_tokens, + ) + + # Deduplicate + if not any(w.str_tokens == ex.str_tokens for w in wrong): + wrong.append(ex) + + return wrong + + def _select_best_idx(self, all_scores: list[list[Optional[ScorerResult]]]) -> int: + """Select index of best round based on judge scorer.""" + scores = [] + for round_scores in all_scores: + if round_scores[self.judge_scorer_index]: + score = self._compute_score(round_scores[self.judge_scorer_index]) + else: + score = float("-inf") + scores.append(score) + + return max(range(len(scores)), key=lambda i: scores[i]) + + def _compute_score(self, result: ScorerResult) -> float: + """Compute score from scorer result.""" + samples = result.score or [] + if not samples: + return float("-inf") + + if hasattr(samples[0], "similarity"): + # Embedding scorer + pos = [s.similarity for s in samples if s.activating] + neg = [s.similarity for s in samples if not s.activating] + if not pos or not neg: + return float("-inf") + return fmean(pos) - fmean(neg) + + # Classifier scorer (F1) + tp = fp = fn = 0 + for sample in samples: + if sample.correct: + if sample.activating: + tp += 1 + else: + if sample.activating: + fn += 1 + else: + fp += 1 + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + return ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0 + ) diff --git a/delphi/scorers/scorer.py b/delphi/scorers/scorer.py index fa5a0ae5..1ae3bbec 100644 --- a/delphi/scorers/scorer.py +++ b/delphi/scorers/scorer.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, NamedTuple +from typing import Any, NamedTuple, Optional from ..latents.latents import LatentRecord @@ -11,6 +11,9 @@ class ScorerResult(NamedTuple): score: Any """Generated score for latent.""" + duration: Optional[float] = None + """Time taken to generate the score in seconds.""" + class Scorer(ABC): @abstractmethod
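Editorial usage sketch (not part of the diff): the wiring below mirrors how process_cache constructs and drives the new BestOfKOrchestrator in this change. The scorer import path, output directory, and the pre-built llm_client are illustrative assumptions, not code taken from the PR.

    # Minimal sketch, assuming `llm_client` has already been built elsewhere
    # (Offline or OpenRouter, as process_cache does) and that DetectionScorer is
    # importable from delphi.scorers as in the existing pipeline (assumed path).
    from pathlib import Path

    from delphi.explainers import BestOfKOrchestrator
    from delphi.scorers import DetectionScorer  # assumed module path

    score_dir = Path("results/scores/detection")  # placeholder output directory
    score_dir.mkdir(parents=True, exist_ok=True)

    detection = DetectionScorer(llm_client, n_examples_shown=5)
    orchestrator = BestOfKOrchestrator(
        client=llm_client,
        scorers_with_paths=[(detection, score_dir)],
        num_explanations=5,     # K candidates, cf. bestofk_num_explanations
        judge_scorer_index=0,   # pick the winner using the detection scorer
        num_train_examples=20,  # examples shown per candidate prompt
    )

    # Per latent record, as in the loop added to process_cache:
    #     result = await orchestrator(record)
    # `result` is the highest-scoring ExplainerResult; its explanation is then
    # written out by explainer_postprocess.

The HillClimbingOrchestrator is driven the same way, with its rounds and feedback limits taken from the new iterative_* fields in RunConfig (iterative_num_rounds, iterative_max_false_positives, iterative_max_false_negatives, iterative_carryforward_strategy).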