Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 84 additions & 5 deletions delphi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from delphi.clients import Offline, OpenRouter
from delphi.config import RunConfig
from delphi.explainers import ContrastiveExplainer, DefaultExplainer, NoOpExplainer
from delphi.explainers.bestofk import BestOfKOrchestrator
from delphi.explainers.explainer import ExplainerResult
from delphi.explainers.iterative import HillClimbingOrchestrator
from delphi.latents import LatentCache, LatentDataset
from delphi.latents.neighbours import NeighbourCalculator
from delphi.log.result_analysis import log_results
Expand Down Expand Up @@ -171,13 +173,90 @@ async def process_cache(
f"Explainer provider {run_cfg.explainer_provider} not supported"
)

if not run_cfg.explainer == "none":
# Define postprocessors used by all explainer modes
def explainer_postprocess(result, is_final=True):
with open(explanations_path / f"{result.record.latent}.txt", "wb") as f:
f.write(orjson.dumps(result.explanation))
return result

def explainer_postprocess(result):
with open(explanations_path / f"{result.record.latent}.txt", "wb") as f:
f.write(orjson.dumps(result.explanation))
def scorer_postprocess_fn(result, score_dir, round_idx=None, is_final=False):
safe_latent = str(result.record.latent).replace("/", "--")
if round_idx is not None and not is_final:
filename = f"{safe_latent}_round{round_idx}.txt"
else:
filename = f"{safe_latent}.txt"
with open(score_dir / filename, "wb") as f:
f.write(orjson.dumps(result.score))

# Handle BestOfK and Iterative orchestrators
if run_cfg.explainer in ("bestofk", "iterative"):
# Build scorers for orchestrators
scorers_with_paths = []
for scorer_name in run_cfg.scorers:
scorer_path = scores_path / scorer_name
scorer_path.mkdir(parents=True, exist_ok=True)

if scorer_name == "simulation":
scorer = OpenAISimulator(
llm_client,
tokenizer=tokenizer,
all_at_once=isinstance(llm_client, Offline),
)
elif scorer_name == "fuzz":
scorer = FuzzingScorer(
llm_client,
n_examples_shown=run_cfg.num_examples_per_scorer_prompt,
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
fuzz_type=run_cfg.fuzz_type,
)
elif scorer_name == "detection":
scorer = DetectionScorer(
llm_client,
n_examples_shown=run_cfg.num_examples_per_scorer_prompt,
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
)
else:
raise ValueError(f"Scorer {scorer_name} not supported")

scorers_with_paths.append((scorer, scorer_path))

if run_cfg.explainer == "bestofk":
orchestrator = BestOfKOrchestrator(
client=llm_client,
scorers_with_paths=scorers_with_paths,
num_explanations=run_cfg.bestofk_num_explanations,
judge_scorer_index=run_cfg.bestofk_judge_scorer_index,
num_train_examples=run_cfg.bestofk_num_train_examples,
verbose=run_cfg.verbose,
scorer_postprocess=scorer_postprocess_fn,
)
else: # iterative
orchestrator = HillClimbingOrchestrator(
client=llm_client,
scorers_with_paths=scorers_with_paths,
num_rounds=run_cfg.iterative_num_rounds,
judge_scorer_index=run_cfg.bestofk_judge_scorer_index,
max_false_positives=run_cfg.iterative_max_false_positives,
max_false_negatives=run_cfg.iterative_max_false_negatives,
carryforward_strategy=run_cfg.iterative_carryforward_strategy,
verbose=run_cfg.verbose,
scorer_postprocess=scorer_postprocess_fn,
explainer_postprocess=explainer_postprocess,
)

# Run orchestrator on each record
for record in dataset:
try:
result = await orchestrator(record)
explainer_postprocess(result)
except Exception as e:
logger.error(f"Orchestrator failed for {record.latent}: {repr(e)}")

return # Skip normal pipeline for orchestrators

return result
elif not run_cfg.explainer == "none":

if run_cfg.constructor_cfg.non_activating_source == "FAISS":
explainer = ContrastiveExplainer(
Expand Down
29 changes: 27 additions & 2 deletions delphi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ class RunConfig(Serializable):
models and 'openrouter' for API calls."""

explainer: str = field(
choices=["default", "none"],
choices=["default", "none", "bestofk", "iterative"],
default="default",
)
"""Explainer to use for generating explanations. Options are 'default' for
the default single token explainer, and 'none' for no explanation generation."""
the default single token explainer, 'none' for no explanation generation,
'bestofk' for best-of-K sampling, and 'iterative' for iterative refinement."""

scorers: list[str] = list_field(
choices=[
Expand All @@ -165,6 +166,30 @@ class RunConfig(Serializable):
examples and highlights n_incorrect tokens. Active uses activating examples
and highlights non-activating tokens."""

# BestOfK explainer config
bestofk_num_explanations: int = field(default=5)
"""Number of explanation candidates to generate for best-of-K selection."""

bestofk_judge_scorer_index: int = field(default=0)
"""Index of the scorer to use for selecting the best explanation."""

bestofk_num_train_examples: int | None = field(default=20)
"""Number of training examples to show per explanation. None uses all available."""

# Iterative explainer config
iterative_num_rounds: int = field(default=3)
"""Number of refinement rounds for iterative explanation."""

iterative_max_false_positives: int = field(default=20)
"""Maximum false positive examples to include in refinement prompts."""

iterative_max_false_negatives: int = field(default=20)
"""Maximum false negative examples to include in refinement prompts."""

iterative_carryforward_strategy: Literal["best", "last"] = "last"
"""Strategy for selecting final explanation: 'best' uses highest-scoring,
'last' uses most recent refinement."""

name: str = ""
"""The name of the run. Results are saved in a directory with this name."""

Expand Down
4 changes: 4 additions & 0 deletions delphi/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .bestofk import BestOfKOrchestrator
from .contrastive_explainer import ContrastiveExplainer
from .default.default import DefaultExplainer
from .explainer import Explainer, explanation_loader, random_explanation_loader
from .iterative import HillClimbingOrchestrator
from .no_op_explainer import NoOpExplainer
from .single_token_explainer import SingleTokenExplainer

Expand All @@ -12,4 +14,6 @@
"random_explanation_loader",
"ContrastiveExplainer",
"NoOpExplainer",
"BestOfKOrchestrator",
"HillClimbingOrchestrator",
]
Loading