diff --git a/.gitignore b/.gitignore index ecc3046..3df7604 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ pnpm-debug.log* # Editor directories and files .idea .vscode +.claude *.suo *.ntvs* *.njsproj @@ -24,3 +25,6 @@ pnpm-debug.log* __pycache__ *egg-info *pyc + +# Test fixtures cache (downloaded experiment data) +tests/fixtures/.cache/ diff --git a/README.md b/README.md index ba30413..04de234 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,17 @@ poetry run align-app --scenarios /data/shared/evaluation_outputs/Ph2-JulyCollab/ poetry run align-app --scenarios /path/to/scenarios1.json /path/to/scenarios2.json /path/to/scenarios_dir ``` +### Load Experiment Results + +You can load pre-computed experiment results using the `--experiments` flag. This extracts unique ADM configurations from experiment directories and adds them to the decider dropdown: + +```console +# Use test fixtures (download first by running: poetry run pytest tests/test_experiment_deciders.py -k download) +poetry run align-app --experiments tests/fixtures/.cache/experiments +``` + +The experiment directory should contain subdirectories with `.hydra/config.yaml` and `input_output.json` files from align-system experiment runs. + ### Optionally Configure Network Port or Host The web server is from Trame. To configure the port, use the `--port` or `-p` arg diff --git a/align_app/adm/config.py b/align_app/adm/config.py index 7e4da6a..4871d7d 100644 --- a/align_app/adm/config.py +++ b/align_app/adm/config.py @@ -3,6 +3,7 @@ from pathlib import Path import align_system from align_app.adm.hydra_config_loader import load_adm_config +from align_app.adm.experiment_config_loader import load_experiment_adm_config from align_app.utils.utils import merge_dicts @@ -28,6 +29,9 @@ def get_decider_config( Merges base decider config with app-level overrides. Two-layer merge: base YAML config + (config_overrides + dataset_overrides) + For experiment configs (experiment_config: True), loads pre-resolved YAML directly. + For edited configs (edited_config: True), returns the stored resolved_config directly. + Args: probe_id: The probe ID to get config for all_deciders: Dict of all available deciders @@ -41,35 +45,38 @@ def get_decider_config( if not decider_cfg: return None - config_path = decider_cfg["config_path"] + is_edited_config = decider_cfg.get("edited_config", False) + is_experiment_config = decider_cfg.get("experiment_config", False) + + if is_edited_config: + config = copy.deepcopy(decider_cfg["resolved_config"]) + if llm_backbone and "structured_inference_engine" in config: + config["structured_inference_engine"]["model_name"] = llm_backbone + return config - # Layer 1: Load base config from align-system YAML - full_cfg = load_adm_config( - config_path, - str(base_align_system_config_dir), - ) - decider_base = full_cfg.get("adm", {}) + # Layer 1: Load base config - either pre-resolved experiment YAML or Hydra compose. + # Both produce same structure with ${ref:...} that initialize_with_custom_references handles. 
+ if is_experiment_config: + experiment_path = Path(decider_cfg["experiment_path"]) + decider_base = load_experiment_adm_config(experiment_path) or {} + else: + config_path = decider_cfg["config_path"] + full_cfg = load_adm_config( + config_path, + str(base_align_system_config_dir), + ) + decider_base = full_cfg.get("adm", {}) # Layer 2: Prepare app-level overrides config_overrides = decider_cfg.get("config_overrides", {}) dataset_overrides = decider_cfg.get("dataset_overrides", {}).get(dataset_name, {}) - # Extract metadata fields from decider entry - metadata = { - k: v - for k, v in decider_cfg.items() - if k in ["llm_backbones", "model_path_keys"] - } - - # Single deep merge: base + config_overrides + dataset_overrides + metadata + # Deep merge: base + config_overrides + dataset_overrides merged_config = copy.deepcopy(decider_base) merged_config = merge_dicts(merged_config, config_overrides) merged_config = merge_dicts(merged_config, dataset_overrides) - merged_config = merge_dicts(merged_config, metadata) - if llm_backbone: - merged_config["llm_backbone"] = llm_backbone - if "structured_inference_engine" in merged_config: - merged_config["structured_inference_engine"]["model_name"] = llm_backbone + if llm_backbone and "structured_inference_engine" in merged_config: + merged_config["structured_inference_engine"]["model_name"] = llm_backbone return merged_config diff --git a/align_app/adm/decider/executor.py b/align_app/adm/decider/executor.py index b2cc800..7d66180 100644 --- a/align_app/adm/decider/executor.py +++ b/align_app/adm/decider/executor.py @@ -1,6 +1,4 @@ from typing import Any, Tuple -import gc -import torch from functools import partial from align_system.utils.hydra_utils import initialize_with_custom_references from align_system.utils.hydrate_state import p2triage_hydrate_scenario_state @@ -95,8 +93,17 @@ def instantiate_adm(decider_config): adm = initialize_with_custom_references({"adm": decider_config})["adm"] - def cleanup(_): - gc.collect() - torch.cuda.empty_cache() + def cleanup(model): + if hasattr(model, "instance"): + instance = model.instance + if hasattr(instance, "steps"): + for step in instance.steps: + if hasattr(step, "structured_inference_engine"): + engine = step.structured_inference_engine + if hasattr(engine, "model") and engine.model is not None: + del engine.model + if hasattr(engine, "sampler"): + del engine.sampler + instance.steps.clear() return partial(choose_action, adm), partial(cleanup, adm) diff --git a/align_app/adm/decider/tests/conftest.py b/align_app/adm/decider/tests/conftest.py index 8c40467..c2ffe75 100644 --- a/align_app/adm/decider/tests/conftest.py +++ b/align_app/adm/decider/tests/conftest.py @@ -60,7 +60,6 @@ def resolved_random_config(): "${ref:adm.step_definitions.populate_choice_info}", ], }, - "model_path_keys": ["structured_inference_engine", "model_name"], } diff --git a/align_app/adm/decider/tests/test_worker.py b/align_app/adm/decider/tests/test_worker.py index b9dc1a7..56ceddb 100644 --- a/align_app/adm/decider/tests/test_worker.py +++ b/align_app/adm/decider/tests/test_worker.py @@ -1,5 +1,113 @@ import multiprocessing as mp from align_app.adm.decider.types import DeciderParams +from align_app.adm.decider.worker import extract_cache_key + + +def mock_worker_with_event_tracking(task_queue, result_queue, event_queue): + """Worker that uses a mock ADM and tracks load/cleanup events via a Queue. + + This implements the FIXED decider_worker_func logic that cleans up + old models before loading new ones. 
+ Events are sent as tuples: ('load', key) or ('cleanup', key) + """ + import hashlib + import json + import logging + import traceback + from typing import Dict, Tuple, Callable + from align_utils.models import ADMResult, Decision, ChoiceInfo + from align_app.adm.decider.types import DeciderParams + + root_logger = logging.getLogger() + root_logger.setLevel("WARNING") + + def extract_key(config): + cache_str = json.dumps(config, sort_keys=True) + return hashlib.md5(cache_str.encode()).hexdigest() + + def mock_instantiate_adm(config): + cache_key = extract_key(config) + event_queue.put(("load", cache_key)) + + def choose_action(params): + return ADMResult( + decision=Decision(unstructured="test", justification="test"), + choice_info=ChoiceInfo( + choice_id="test", + choice_kdma_association=[], + choice_description="", + ), + ) + + def cleanup(): + event_queue.put(("cleanup", cache_key)) + + return choose_action, cleanup + + model_cache: Dict[str, Tuple[Callable, Callable]] = {} + + try: + for task in iter(task_queue.get, None): + try: + params: DeciderParams = task + cache_key = extract_key(params.resolved_config) + + if cache_key not in model_cache: + old_cleanups = [cleanup for _, (_, cleanup) in model_cache.items()] + model_cache.clear() + for cleanup in old_cleanups: + cleanup() + + choose_action_func, cleanup_func = mock_instantiate_adm( + params.resolved_config + ) + model_cache[cache_key] = (choose_action_func, cleanup_func) + else: + choose_action_func, _ = model_cache[cache_key] + + result: ADMResult = choose_action_func(params) + result_queue.put(result) + + except (KeyboardInterrupt, SystemExit): + break + except Exception as e: + error_msg = f"{str(e)}\n{traceback.format_exc()}" + result_queue.put(Exception(error_msg)) + finally: + for _, (_, cleanup_func) in model_cache.items(): + try: + cleanup_func() + except Exception: + pass + event_queue.put(None) + + +class TestExtractCacheKey: + def test_same_config_produces_same_key(self): + config = {"model_name": "test-model", "temperature": 0.7} + key1 = extract_cache_key(config) + key2 = extract_cache_key(config) + assert key1 == key2 + + def test_different_configs_produce_different_keys(self): + config1 = {"model_name": "test-model", "temperature": 0.7} + config2 = {"model_name": "test-model", "temperature": 0.8} + key1 = extract_cache_key(config1) + key2 = extract_cache_key(config2) + assert key1 != key2 + + def test_same_model_different_settings_produce_different_keys(self): + config1 = { + "structured_inference_engine": {"model_name": "same-model"}, + "setting_a": "value1", + } + config2 = { + "structured_inference_engine": {"model_name": "same-model"}, + "setting_a": "value2", + } + key1 = extract_cache_key(config1) + key2 = extract_cache_key(config2) + assert key1 != key2 class TestDeciderWorker: @@ -166,3 +274,142 @@ def test_worker_shuts_down_cleanly(self, worker_queues): worker_process.join(timeout=5) assert not worker_process.is_alive() + + +def collect_events_from_queue(event_queue, timeout=1.0): + """Collect all events from queue until None sentinel.""" + import queue + + events = [] + while True: + try: + event = event_queue.get(timeout=timeout) + if event is None: + break + events.append(event) + except queue.Empty: + break + return events + + +class TestCleanupOnADMSwitch: + """Tests for GPU memory cleanup when switching between ADMs.""" + + def test_cleanup_called_before_loading_new_adm( + self, + scenario_input, + alignment_target_baseline, + ): + """When loading a new ADM config, cleanup should be called for the 
old one BEFORE loading the new one. + + This ensures GPU memory is freed before allocating memory for the new model. + + Expected event order: + 1. ('load', key_a) + 2. ('cleanup', key_a) <-- cleanup BEFORE loading new model + 3. ('load', key_b) + 4. ('cleanup', key_b) <-- final cleanup on shutdown + + Buggy behavior would be: + 1. ('load', key_a) + 2. ('load', key_b) <-- no cleanup, models accumulate! + 3. ('cleanup', key_a) + 4. ('cleanup', key_b) + """ + ctx = mp.get_context("spawn") + task_queue = ctx.Queue() + result_queue = ctx.Queue() + event_queue = ctx.Queue() + + worker_process = ctx.Process( + target=mock_worker_with_event_tracking, + args=(task_queue, result_queue, event_queue), + ) + worker_process.start() + + config_a = {"model": "model_a", "setting": "value_a"} + config_b = {"model": "model_b", "setting": "value_b"} + + params_a = DeciderParams( + scenario_input=scenario_input, + alignment_target=alignment_target_baseline, + resolved_config=config_a, + ) + task_queue.put(params_a) + result_queue.get(timeout=5) + + params_b = DeciderParams( + scenario_input=scenario_input, + alignment_target=alignment_target_baseline, + resolved_config=config_b, + ) + task_queue.put(params_b) + result_queue.get(timeout=5) + + task_queue.put(None) + worker_process.join(timeout=5) + + events = collect_events_from_queue(event_queue) + + load_events = [(i, e) for i, e in enumerate(events) if e[0] == "load"] + assert len(load_events) == 2, f"Expected 2 load events, got {len(load_events)}" + + load_a_idx = load_events[0][0] + load_b_idx = load_events[1][0] + key_a = load_events[0][1][1] + + cleanup_a_before_load_b = any( + e[0] == "cleanup" and e[1] == key_a + for e in events[load_a_idx + 1 : load_b_idx] + ) + + assert cleanup_a_before_load_b, ( + f"Expected cleanup of first model BEFORE loading second model.\n" + f"Events: {events}\n" + f"This indicates GPU memory is not being freed when switching ADMs." 
+ ) + + def test_no_cleanup_when_same_adm_reused( + self, + scenario_input, + alignment_target_baseline, + ): + """When using the same ADM config, no cleanup should happen during operation.""" + ctx = mp.get_context("spawn") + task_queue = ctx.Queue() + result_queue = ctx.Queue() + event_queue = ctx.Queue() + + worker_process = ctx.Process( + target=mock_worker_with_event_tracking, + args=(task_queue, result_queue, event_queue), + ) + worker_process.start() + + config_a = {"model": "model_a", "setting": "value_a"} + + params1 = DeciderParams( + scenario_input=scenario_input, + alignment_target=alignment_target_baseline, + resolved_config=config_a, + ) + task_queue.put(params1) + result_queue.get(timeout=5) + + params2 = DeciderParams( + scenario_input=scenario_input, + alignment_target=alignment_target_baseline, + resolved_config=config_a, + ) + task_queue.put(params2) + result_queue.get(timeout=5) + + task_queue.put(None) + worker_process.join(timeout=5) + + events = collect_events_from_queue(event_queue) + + load_events = [e for e in events if e[0] == "load"] + assert len(load_events) == 1, ( + f"Expected only 1 load event when reusing same config, got {len(load_events)}" + ) diff --git a/align_app/adm/decider/worker.py b/align_app/adm/decider/worker.py index 8c5ec20..c80af6b 100644 --- a/align_app/adm/decider/worker.py +++ b/align_app/adm/decider/worker.py @@ -1,3 +1,4 @@ +import gc import hashlib import json import logging @@ -10,19 +11,8 @@ def extract_cache_key(resolved_config: Dict[str, Any]) -> str: - llm_backbone = resolved_config.get("llm_backbone", {}) - model_path_keys = resolved_config.get("model_path_keys", []) - - cache_parts = [] - for key in model_path_keys: - if key in llm_backbone: - cache_parts.append(f"{key}={llm_backbone[key]}") - - if not cache_parts: - cache_str = json.dumps(resolved_config, sort_keys=True) - return hashlib.md5(cache_str.encode()).hexdigest() - - return "_".join(cache_parts) + cache_str = json.dumps(resolved_config, sort_keys=True) + return hashlib.md5(cache_str.encode()).hexdigest() def decider_worker_func(task_queue: Queue, result_queue: Queue): @@ -38,6 +28,19 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue): cache_key = extract_cache_key(params.resolved_config) if cache_key not in model_cache: + old_cleanups = [cleanup for _, (_, cleanup) in model_cache.items()] + model_cache.clear() + for cleanup in old_cleanups: + cleanup() + del old_cleanups + + import torch + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + choose_action_func, cleanup_func = instantiate_adm( params.resolved_config ) diff --git a/align_app/adm/decider_definitions.py b/align_app/adm/decider_definitions.py index 80ee3c1..38a02a9 100644 --- a/align_app/adm/decider_definitions.py +++ b/align_app/adm/decider_definitions.py @@ -103,7 +103,6 @@ def create_decider_entry(config_path, overrides={}): return { "config_path": config_path, "llm_backbones": LLM_BACKBONES, - "model_path_keys": ["structured_inference_engine", "model_name"], "dataset_overrides": {}, **overrides, } @@ -113,9 +112,9 @@ def create_decider_entry(config_path, overrides={}): "phase2_pipeline_zeroshot_comparative_regression": create_decider_entry( "adm/phase2_pipeline_zeroshot_comparative_regression.yaml", { + "max_alignment_attributes": 10, "config_overrides": { "comparative_regression_choice_schema": {"reasoning_max_length": -1}, - "max_alignment_attributes": 10, }, "system_prompt_generator": 
_generate_comparative_regression_pipeline_system_prompt, }, @@ -123,9 +122,9 @@ def create_decider_entry(config_path, overrides={}): "phase2_pipeline_fewshot_comparative_regression": create_decider_entry( "adm/phase2_pipeline_fewshot_comparative_regression.yaml", { + "max_alignment_attributes": 10, "config_overrides": { "comparative_regression_choice_schema": {"reasoning_max_length": -1}, - "max_alignment_attributes": 10, "step_definitions": { "regression_icl": { "icl_generator_partial": { @@ -179,13 +178,12 @@ def create_runtime_decider_entry(config_path): ) -def get_all_deciders(config_paths=[]): - """Get all deciders, merging runtime configs from paths with base deciders.""" - runtime_deciders = { +def get_runtime_deciders(config_paths): + """Get runtime deciders from CLI config paths.""" + return { Path(config_path).stem: create_runtime_decider_entry(config_path) for config_path in config_paths } - return {**runtime_deciders, **_BASE_DECIDERS} def get_system_prompt( diff --git a/align_app/adm/decider_registry.py b/align_app/adm/decider_registry.py index 394a298..c79aceb 100644 --- a/align_app/adm/decider_registry.py +++ b/align_app/adm/decider_registry.py @@ -1,7 +1,11 @@ from functools import partial from collections import namedtuple from typing import Dict, Any -from .decider_definitions import get_all_deciders, get_system_prompt +from .decider_definitions import ( + get_runtime_deciders, + get_system_prompt, + _BASE_DECIDERS, +) from .config import get_decider_config, _get_dataset_name @@ -21,7 +25,7 @@ def _get_decider_options( Dict with option fields, or None if decider doesn't exist for probe's dataset """ try: - dataset_name = _get_dataset_name(probe_id, datasets) + _get_dataset_name(probe_id, datasets) except ValueError: return None @@ -29,16 +33,9 @@ def _get_decider_options( if not decider_cfg: return None - config_overrides = decider_cfg.get("config_overrides", {}) - dataset_overrides = decider_cfg.get("dataset_overrides", {}).get(dataset_name, {}) - metadata = { "llm_backbones": decider_cfg.get("llm_backbones", []), - "model_path_keys": decider_cfg.get("model_path_keys", []), - "max_alignment_attributes": config_overrides.get( - "max_alignment_attributes", - dataset_overrides.get("max_alignment_attributes", 0), - ), + "max_alignment_attributes": decider_cfg.get("max_alignment_attributes", 0), "config_path": decider_cfg.get("config_path"), "exists": True, } @@ -53,18 +50,97 @@ def _get_decider_options( "get_decider_options", "get_system_prompt", "get_all_deciders", + "add_edited_decider", + "add_deciders", ], ) -def create_decider_registry(config_paths, scenario_registry): +def _get_root_decider_name(decider_name: str) -> str: + """Extract the root decider name without any ' - edit N' suffix.""" + import re + + match = re.match(r"^(.+?) - edit \d+$", decider_name) + if match: + return _get_root_decider_name(match.group(1)) + return decider_name + + +def _find_matching_decider( + resolved_config: Dict[str, Any], + all_deciders: Dict[str, Any], +) -> str | None: + for name, entry in all_deciders.items(): + if entry.get("resolved_config") == resolved_config: + return name + return None + + +def _add_edited_decider( + base_decider_name: str, + resolved_config: Dict[str, Any], + llm_backbones: list, + all_deciders: Dict[str, Any], +) -> str: + """ + Add an edited decider to the registry. 
+ + Args: + base_decider_name: Original decider name this was edited from + resolved_config: The edited resolved config + llm_backbones: Available LLM backbones for this decider + all_deciders: The mutable deciders dictionary (pre-bound via partial) + + Returns: + The new decider name "{root_decider_name} - edit {n}" + """ + existing = _find_matching_decider(resolved_config, all_deciders) + if existing: + return existing + + root_name = _get_root_decider_name(base_decider_name) + + edit_count = 1 + for name in all_deciders: + if name.startswith(f"{root_name} - edit "): + try: + n = int(name.split(" - edit ")[-1]) + edit_count = max(edit_count, n + 1) + except ValueError: + pass + + new_name = f"{root_name} - edit {edit_count}" + all_deciders[new_name] = { + "edited_config": True, + "resolved_config": resolved_config, + "llm_backbones": llm_backbones, + "max_alignment_attributes": 10, + } + return new_name + + +def create_decider_registry(config_paths, scenario_registry, experiment_deciders=None): """ Takes config paths and scenario_registry, returns a DeciderRegistry namedtuple with all_deciders and datasets pre-bound using partial application. + + Args: + config_paths: List of paths to runtime decider configs + scenario_registry: Registry for scenarios/probes + experiment_deciders: Optional dict of experiment deciders to merge """ - all_deciders = get_all_deciders(config_paths) + all_deciders = { + **_BASE_DECIDERS, + **(experiment_deciders or {}), + **get_runtime_deciders(config_paths), + } datasets = scenario_registry.get_datasets() + def add_deciders(new_deciders: Dict[str, Any]): + for name, entry in new_deciders.items(): + if name not in all_deciders: + all_deciders[name] = entry + return DeciderRegistry( get_decider_config=partial( get_decider_config, @@ -82,4 +158,9 @@ def create_decider_registry(config_paths, scenario_registry): datasets=datasets, ), get_all_deciders=lambda: all_deciders, + add_edited_decider=partial( + _add_edited_decider, + all_deciders=all_deciders, + ), + add_deciders=add_deciders, ) diff --git a/align_app/adm/experiment_config_loader.py b/align_app/adm/experiment_config_loader.py new file mode 100644 index 0000000..8d5fd99 --- /dev/null +++ b/align_app/adm/experiment_config_loader.py @@ -0,0 +1,17 @@ +"""Shared loader for experiment config files.""" + +from functools import lru_cache +from pathlib import Path +from typing import Dict, Any +import yaml + + +@lru_cache(maxsize=256) +def load_experiment_adm_config(experiment_path: Path) -> Dict[str, Any] | None: + """Load the adm config from experiment's .hydra/config.yaml.""" + config_path = experiment_path / ".hydra" / "config.yaml" + if not config_path.exists(): + return None + with open(config_path) as f: + config = yaml.safe_load(f) + return config.get("adm", config) diff --git a/align_app/adm/experiment_converters.py b/align_app/adm/experiment_converters.py new file mode 100644 index 0000000..a4679a2 --- /dev/null +++ b/align_app/adm/experiment_converters.py @@ -0,0 +1,151 @@ +"""Pure functions to convert experiment data to domain types.""" + +import copy +import hashlib +import json +import uuid +from pathlib import Path +from typing import List, Dict, Any, Optional + +from align_utils.models import ( + ExperimentItem, + ExperimentData, + ADMResult, + Decision, + ChoiceInfo, +) + +from .probe import Probe, get_probe_id +from .decider_definitions import LLM_BACKBONES +from .experiment_config_loader import load_experiment_adm_config +from .decider.types import DeciderParams +from .run_models import Run, 
RunDecision + + +def probes_from_experiment_items(items: List[ExperimentItem]) -> List[Probe]: + """Convert experiment items to probes, deduping by probe_id.""" + seen = set() + probes = [] + for item in items: + probe = Probe.from_input_output_item(item.item) + if probe.probe_id not in seen: + seen.add(probe.probe_id) + probes.append(probe) + return probes + + +def deciders_from_experiments( + experiments: List[ExperimentData], +) -> Dict[str, Dict[str, Any]]: + """Extract unique decider configs from experiments. + + Returns dict: {decider_name: decider_entry} + """ + seen_hashes: Dict[str, tuple] = {} + + for exp in experiments: + adm_config = load_experiment_adm_config(exp.experiment_path) + if adm_config is None: + continue + + normalized = _normalize_adm_config(adm_config) + config_hash = _hash_config(normalized) + + if config_hash not in seen_hashes: + exp_name = exp.experiment_path.parent.name + + if "structured_inference_engine" in adm_config: + experiment_llm = adm_config["structured_inference_engine"].get( + "model_name" + ) + llm_backbones = ( + [experiment_llm] + + [llm for llm in LLM_BACKBONES if llm != experiment_llm] + if experiment_llm + else list(LLM_BACKBONES) + ) + else: + llm_backbones = [] + + decider_entry = { + "experiment_path": str(exp.experiment_path), + "experiment_config": True, + "llm_backbones": llm_backbones, + "max_alignment_attributes": 10, + } + seen_hashes[config_hash] = (exp_name, decider_entry) + + return {name: entry for name, entry in seen_hashes.values()} + + +def _normalize_adm_config(config: Dict[str, Any]) -> Dict[str, Any]: + """Normalize config for comparison by stripping absolute paths to filenames.""" + normalized = copy.deepcopy(config) + _normalize_paths_recursive(normalized) + return normalized + + +def _normalize_paths_recursive(obj: Any) -> None: + """Recursively normalize path-like strings to just filenames.""" + if isinstance(obj, dict): + for key, value in obj.items(): + if isinstance(value, str) and "/" in value and value.endswith(".json"): + obj[key] = Path(value).name + else: + _normalize_paths_recursive(value) + elif isinstance(obj, list): + for i, item in enumerate(obj): + if isinstance(item, str) and "/" in item and item.endswith(".json"): + obj[i] = Path(item).name + else: + _normalize_paths_recursive(item) + + +def _hash_config(config: Dict[str, Any]) -> str: + """Create deterministic hash of config dict.""" + config_str = json.dumps(config, sort_keys=True) + return hashlib.sha256(config_str.encode()).hexdigest()[:16] + + +def run_from_experiment_item(item: ExperimentItem) -> Optional[Run]: + """Convert ExperimentItem to Run with decision populated.""" + if not item.item.output: + return None + + probe_id = get_probe_id(item.item) + + resolved_config = load_experiment_adm_config(item.experiment_path) or {} + decider_params = DeciderParams( + scenario_input=item.item.input, + alignment_target=item.config.alignment_target, + resolved_config=resolved_config, + ) + + output = item.item.output + decision = RunDecision( + adm_result=ADMResult( + decision=Decision( + unstructured=output.action.unstructured, + justification=output.action.justification or "", + ), + choice_info=item.item.choice_info or ChoiceInfo(), + ), + choice_index=output.choice, + ) + + decider_name = item.experiment_path.parent.name + + return Run( + id=str(uuid.uuid4()), + probe_id=probe_id, + decider_name=decider_name, + llm_backbone_name=item.config.adm.llm_backbone or "N/A", + system_prompt="", + decider_params=decider_params, + decision=decision, + ) + + +def 
runs_from_experiment_items(items: List[ExperimentItem]) -> List[Run]: + """Convert experiment items to runs, filtering out items without output.""" + return [run for item in items if (run := run_from_experiment_item(item))] diff --git a/align_app/adm/experiment_results_registry.py b/align_app/adm/experiment_results_registry.py new file mode 100644 index 0000000..13b9b35 --- /dev/null +++ b/align_app/adm/experiment_results_registry.py @@ -0,0 +1,32 @@ +from collections import namedtuple +from pathlib import Path +from typing import List +from align_utils.discovery import parse_experiments_directory +from align_utils.models import ExperimentItem, ExperimentData, get_experiment_items + + +ExperimentResultsRegistry = namedtuple( + "ExperimentResultsRegistry", + [ + "get_all_items", + "get_experiments", + ], +) + + +def create_experiment_results_registry( + experiments_path: Path, +) -> ExperimentResultsRegistry: + """ + Creates an ExperimentResultsRegistry with pre-computed experiment results. + Loads experiment directories containing input_output.json + hydra configs. + """ + experiments: List[ExperimentData] = parse_experiments_directory(experiments_path) + all_items: List[ExperimentItem] = [ + item for exp in experiments for item in get_experiment_items(exp) + ] + + return ExperimentResultsRegistry( + get_all_items=lambda: all_items, + get_experiments=lambda: experiments, + ) diff --git a/align_app/adm/probe.py b/align_app/adm/probe.py index ca5eab4..be0e2d2 100644 --- a/align_app/adm/probe.py +++ b/align_app/adm/probe.py @@ -3,6 +3,19 @@ from align_utils.models import InputOutputItem +def get_probe_id(item: InputOutputItem) -> str: + """Extract probe_id from InputOutputItem in format '{scenario_id}.{scene_id}'.""" + if not item.input or not item.input.full_state: + raise ValueError("InputOutputItem must have input and full_state") + + full_state = item.input.full_state + if "meta_info" not in full_state or "scene_id" not in full_state["meta_info"]: + raise ValueError("InputOutputItem missing required meta_info.scene_id") + + scene_id = full_state["meta_info"]["scene_id"] + return f"{item.input.scenario_id}.{scene_id}" + + class Probe(BaseModel): """ Wrapper around InputOutputItem that adds derived fields for convenient access. 
@@ -41,21 +54,12 @@ def from_input_output_item(cls, item: InputOutputItem) -> "Probe": Raises: ValueError: If required fields are missing from the input data """ - if not item.input or not item.input.full_state: - raise ValueError("InputOutputItem must have input and full_state") - - if ( - "meta_info" not in item.input.full_state - or "scene_id" not in item.input.full_state["meta_info"] - ): - raise ValueError("InputOutputItem missing required meta_info.scene_id") - - scene_id = item.input.full_state["meta_info"]["scene_id"] - probe_id = f"{item.input.scenario_id}.{scene_id}" + probe_id = get_probe_id(item) + full_state = item.input.full_state + assert full_state is not None + scene_id = full_state["meta_info"]["scene_id"] - display_state = None - if "unstructured" in item.input.full_state: - display_state = item.input.full_state["unstructured"] + display_state = full_state.get("unstructured") return cls( item=item, diff --git a/align_app/adm/probe_registry.py b/align_app/adm/probe_registry.py index e732d4d..2260e69 100644 --- a/align_app/adm/probe_registry.py +++ b/align_app/adm/probe_registry.py @@ -46,6 +46,7 @@ def _truncate_probe(probe: Probe) -> Probe: "get_datasets", "get_attributes", "add_edited_probe", + "add_probes", ], ) @@ -108,12 +109,41 @@ def get_attributes(probe_id): dataset_info = datasets[dataset_name] return dataset_info.get("attributes", {}) + def _probes_content_equal( + probe: Probe, edited_text: str, edited_choices: List[Dict[str, Any]] + ) -> bool: + if probe.display_state != edited_text: + return False + probe_choices = probe.choices or [] + if len(probe_choices) != len(edited_choices): + return False + for pc, ec in zip(probe_choices, edited_choices): + if pc.get("unstructured") != ec.get("unstructured"): + return False + return True + + def _find_matching_probe( + scenario_id: str, edited_text: str, edited_choices: List[Dict[str, Any]] + ) -> Probe | None: + for probe in probes.values(): + if probe.scenario_id != scenario_id: + continue + if _probes_content_equal(probe, edited_text, edited_choices): + return probe + return None + def add_edited_probe( base_probe_id: str, edited_text: str, edited_choices: List[Dict[str, Any]] ) -> Probe: """Create new probe with edited content and -edit-N suffix.""" base_probe = get_probe(base_probe_id) + existing = _find_matching_probe( + base_probe.scenario_id, edited_text, edited_choices + ) + if existing: + return existing + base_scene = base_probe.scene_id.split(" edit ")[0] edit_num = 1 for existing_id in probes: @@ -146,6 +176,13 @@ def add_edited_probe( return new_probe + def add_probes(new_probes: List[Probe]): + """Add probes to registry, skipping duplicates.""" + for probe in new_probes: + if probe.probe_id not in probes: + probes[probe.probe_id] = probe + datasets["phase2"]["probes"][probe.probe_id] = probe + return ProbeRegistry( get_probes=lambda: probes, get_dataset_name=get_dataset_name, @@ -153,4 +190,5 @@ def add_edited_probe( get_datasets=lambda: datasets, get_attributes=get_attributes, add_edited_probe=add_edited_probe, + add_probes=add_probes, ) diff --git a/align_app/app/run_models.py b/align_app/adm/run_models.py similarity index 77% rename from align_app/app/run_models.py rename to align_app/adm/run_models.py index e102d36..289ebf5 100644 --- a/align_app/app/run_models.py +++ b/align_app/adm/run_models.py @@ -2,7 +2,7 @@ from pydantic import BaseModel import hashlib import json -from ..adm.decider.types import DeciderParams +from .decider.types import DeciderParams from align_utils.models import ADMResult 
@@ -12,13 +12,23 @@ def hash_run_params( llm_backbone_name: str, decider_params: DeciderParams, ) -> str: + alignment_target = decider_params.alignment_target + kdma_values = [kv.model_dump() for kv in alignment_target.kdma_values] + + # Exclude alignment_target from resolved_config since it's already in kdma_values + resolved_config = decider_params.resolved_config or {} + config_for_hash = { + k: v for k, v in resolved_config.items() if k != "alignment_target" + } + cache_key_data = { "probe_id": probe_id, "decider": decider_name, "llm_backbone": llm_backbone_name, - "alignment_target": decider_params.alignment_target.model_dump(), + "kdma_values": kdma_values, "state": decider_params.scenario_input.state, "choices": decider_params.scenario_input.choices, + "resolved_config": config_for_hash, } json_str = json.dumps(cache_key_data, sort_keys=True) return hashlib.md5(json_str.encode()).hexdigest() diff --git a/align_app/adm/types.py b/align_app/adm/types.py index e1d4140..faee32c 100644 --- a/align_app/adm/types.py +++ b/align_app/adm/types.py @@ -54,12 +54,20 @@ class DeciderContext(Prompt): resolved_config: dict +def _alignment_target_id_from_attributes(attributes: List[Attribute]) -> str: + """Generate alignment target ID from attributes (e.g., 'affiliation-0.5_merit-0.3').""" + if not attributes: + return "unknown" + parts = [f"{a['type']}-{a['score']}" for a in attributes] + return "_".join(sorted(parts)) + + def attributes_to_alignment_target( attributes: List[Attribute], ) -> AlignmentTarget: """Create AlignmentTarget Pydantic model from attributes.""" return AlignmentTarget( - id="ad_hoc", + id=_alignment_target_id_from_attributes(attributes), kdma_values=[ KDMAValue( kdma=a["type"], diff --git a/align_app/app/core.py b/align_app/app/core.py index 9213ddf..171c5ee 100644 --- a/align_app/app/core.py +++ b/align_app/app/core.py @@ -1,11 +1,13 @@ +from pathlib import Path from trame.app import get_server from trame.decorators import TrameApp, controller from . 
import ui from .search import SearchController -from .runs_registry import create_runs_registry +from .runs_registry import RunsRegistry from .runs_state_adapter import RunsStateAdapter from ..adm.decider_registry import create_decider_registry from ..adm.probe_registry import create_probe_registry +from .import_experiments import import_experiments @TrameApp() @@ -28,24 +30,49 @@ def __init__(self, server=None): help="Paths to scenarios JSON files or directories of JSON files (space-separated)", ) + self.server.cli.add_argument( + "--experiments", + help="Path to directory containing pre-computed experiment results", + ) + args, _ = self.server.cli.parse_known_args() - self._probe_registry = create_probe_registry(args.scenarios) + # Skip default probes if either --scenarios or --experiments is provided + scenarios_paths = args.scenarios + if args.experiments and scenarios_paths is None: + scenarios_paths = [] + + self._probe_registry = create_probe_registry(scenarios_paths) + + experiment_result = None + if args.experiments: + experiment_result = import_experiments(Path(args.experiments)) + self._probe_registry.add_probes(experiment_result.probes) + self._decider_registry = create_decider_registry( - args.deciders or [], self._probe_registry + args.deciders or [], + self._probe_registry, + experiment_deciders=experiment_result.deciders if experiment_result else {}, ) - self._runs_registry = create_runs_registry( + self._runs_registry = RunsRegistry( self._probe_registry, self._decider_registry, ) - self._search_controller = SearchController(self.server, self._probe_registry) + + if experiment_result: + self._runs_registry.add_experiment_items(experiment_result.items) + self._runsController = RunsStateAdapter( self.server, self._probe_registry, self._decider_registry, self._runs_registry, ) - self._search_controller.set_runs_state_adapter(self._runsController) + self._search_controller = SearchController( + self.server, + self._probe_registry, + on_search_select=self._handle_search_select, + ) if self.server.hot_reload: self.server.controller.on_server_reload.add(self._build_ui) @@ -53,6 +80,10 @@ def __init__(self, server=None): self._build_ui() self.reset_state() + def _handle_search_select(self, run_id: str, scenario_id: str, scene_id: str): + new_run_id = self._runsController.update_run_scenario(run_id, scenario_id) + self._runsController.update_run_scene(new_run_id, scene_id) + @controller.set("reset_state") def reset_state(self): self._runsController.reset_state() diff --git a/align_app/app/export_experiments.py b/align_app/app/export_experiments.py new file mode 100644 index 0000000..bb65667 --- /dev/null +++ b/align_app/app/export_experiments.py @@ -0,0 +1,165 @@ +"""Export runs as Pydantic Experiment structures in ZIP format.""" + +import io +import json +import zipfile +from typing import Any, Dict, List, Tuple + +import yaml +from align_utils.models import ( + Action, + ChoiceInfo, + InputData, + InputOutputItem, + Output, +) + + +def _extract_choice_index(decision: Dict) -> int: + """Extract choice index from decision unstructured text (A. -> 0, B. 
-> 1).""" + if "unstructured" in decision: + text = decision["unstructured"] + if text and len(text) > 0: + first_char = text[0].upper() + if first_char.isalpha() and first_char >= "A": + return ord(first_char) - ord("A") + return 0 + + +def run_dict_to_input_output_item( + run_dict: Dict[str, Any], alignment_target_id: str +) -> InputOutputItem: + """Convert a run state dict to InputOutputItem Pydantic model.""" + prompt = run_dict["prompt"] + decision = run_dict.get("decision") + + input_data = InputData( + scenario_id=prompt["probe"]["scenario_id"], + alignment_target_id=alignment_target_id, + full_state=prompt["probe"]["full_state"], + state=prompt["probe"].get("state") + or prompt["probe"]["full_state"].get("unstructured"), + choices=prompt["probe"]["choices"], + ) + + output = None + choice_info = None + if decision: + choice_idx = _extract_choice_index(decision) + choices = prompt["probe"]["choices"] + + action_dict = choices[choice_idx] if choice_idx < len(choices) else choices[0] + action = Action( + action_id=action_dict.get("action_id", f"choice_{choice_idx}"), + action_type=action_dict.get("action_type", "CHOICE"), + unstructured=action_dict.get("unstructured", ""), + justification=decision.get("justification", ""), + character_id=action_dict.get("character_id"), + intent_action=action_dict.get("intent_action"), + kdma_association=action_dict.get("kdma_association"), + ) + output = Output(choice=choice_idx, action=action) + + if "choice_info" in decision and decision["choice_info"]: + choice_info = ChoiceInfo(**decision["choice_info"]) + + return InputOutputItem( + input=input_data, + output=output, + choice_info=choice_info, + label=None, + ) + + +def _alignment_target_id_from_kdma_values(alignment_target: Dict[str, Any]) -> str: + """Generate alignment target ID from KDMA values (e.g., 'affiliation-0.5_merit-0.3').""" + kdma_values = alignment_target.get("kdma_values", []) + if not kdma_values: + return "unknown" + parts = [f"{kv.get('kdma')}-{kv.get('value', 0.0)}" for kv in kdma_values] + return "_".join(sorted(parts)) + + +def _get_alignment_target_id(run_dict: Dict[str, Any]) -> str: + """Get alignment target ID, generating one from KDMA values if 'ad_hoc' or missing.""" + alignment_target = run_dict["prompt"].get("alignment_target") or {} + target_id = alignment_target.get("id", "") + if not target_id or target_id == "ad_hoc": + return _alignment_target_id_from_kdma_values(alignment_target) + return target_id + + +def _group_runs_by_experiment( + runs_dict: Dict[str, Dict[str, Any]], +) -> Dict[Tuple[str, str], List[Dict[str, Any]]]: + """Group runs by (decider_name, alignment_target_id).""" + groups: Dict[Tuple[str, str], List[Dict[str, Any]]] = {} + + for run_dict in runs_dict.values(): + if not run_dict.get("decision"): + continue + + decider_name = run_dict["prompt"]["decider"]["name"] + alignment_target_id = _get_alignment_target_id(run_dict) + key = (decider_name, alignment_target_id) + + if key not in groups: + groups[key] = [] + groups[key].append(run_dict) + + return groups + + +def _build_experiment_config( + run_dict: Dict[str, Any], + decider_name: str, + alignment_target_id: str, +) -> Dict[str, Any]: + """Build experiment config matching align_utils ExperimentConfig format.""" + resolved_config = run_dict["prompt"].get("resolved_config") or {} + alignment_target = run_dict["prompt"]["alignment_target"] + + kdma_values = [ + {"kdma": kv.get("kdma"), "value": kv.get("value", 0.0), "kdes": kv.get("kdes")} + for kv in alignment_target.get("kdma_values", []) + ] 
+ + return { + "adm": resolved_config, + "alignment_target": {"id": alignment_target_id, "kdma_values": kdma_values}, + } + + +def export_runs_to_zip(runs_dict: Dict[str, Dict[str, Any]]) -> bytes: + """Export runs to ZIP file as bytes for browser download.""" + groups = _group_runs_by_experiment(runs_dict) + + if not groups: + return b"" + + zip_buffer = io.BytesIO() + + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: + for (decider_name, alignment_target_id), run_dicts in groups.items(): + items = [ + run_dict_to_input_output_item(rd, alignment_target_id) + for rd in run_dicts + ] + + items_json = json.dumps( + [item.model_dump(exclude_none=True) for item in items], + indent=2, + ) + + config = _build_experiment_config( + run_dicts[0], decider_name, alignment_target_id + ) + config_yaml = yaml.dump(config, default_flow_style=False, sort_keys=False) + + base_path = f"{decider_name}/{alignment_target_id}" + + zf.writestr(f"{base_path}/input_output.json", items_json) + zf.writestr(f"{base_path}/.hydra/config.yaml", config_yaml) + + zip_buffer.seek(0) + return zip_buffer.read() diff --git a/align_app/app/import_experiments.py b/align_app/app/import_experiments.py new file mode 100644 index 0000000..6ff4e77 --- /dev/null +++ b/align_app/app/import_experiments.py @@ -0,0 +1,127 @@ +"""Import experiments from ZIP files and directories.""" + +import io +import tempfile +import uuid +import zipfile +from dataclasses import dataclass +from pathlib import Path +from typing import List, Dict, Optional + +from align_utils.discovery import parse_experiments_directory +from align_utils.models import ( + ExperimentData, + ExperimentItem, + get_experiment_items, + ADMResult, + Decision, + ChoiceInfo, +) + +from ..adm.experiment_converters import ( + deciders_from_experiments, + probes_from_experiment_items, +) +from ..adm.experiment_config_loader import load_experiment_adm_config +from ..adm.probe import Probe, get_probe_id +from ..adm.decider.types import DeciderParams +from ..adm.run_models import Run, RunDecision +from .runs_presentation import compute_experiment_item_cache_key + + +@dataclass +class StoredExperimentItem: + """ExperimentItem with pre-loaded config for use after temp dir cleanup.""" + + item: ExperimentItem + resolved_config: Dict + cache_key: str + + +@dataclass +class ExperimentImportResult: + """Result of importing experiments.""" + + probes: List[Probe] + deciders: dict + items: Dict[str, StoredExperimentItem] + + +def import_experiments(experiments_path: Path) -> ExperimentImportResult: + """Import experiments from a directory path. + + Returns ExperimentImportResult with probes, deciders, and items keyed by cache_key. 
+ """ + print(f"Loading experiments from {experiments_path}...") + experiments: List[ExperimentData] = parse_experiments_directory(experiments_path) + all_items: List[ExperimentItem] = [ + item for exp in experiments for item in get_experiment_items(exp) + ] + + probes = probes_from_experiment_items(all_items) + deciders = deciders_from_experiments(experiments) + + items: Dict[str, StoredExperimentItem] = {} + for item in all_items: + resolved_config = load_experiment_adm_config(item.experiment_path) or {} + cache_key = compute_experiment_item_cache_key(item, resolved_config) + items[cache_key] = StoredExperimentItem(item, resolved_config, cache_key) + + print(f"Loaded {len(items)} experiment items from {len(experiments)} experiments") + return ExperimentImportResult(probes, deciders, items) + + +def import_experiments_from_zip(zip_bytes: bytes) -> ExperimentImportResult: + """Extract and parse experiments from a ZIP file. + + Returns ExperimentImportResult with probes, deciders, and items keyed by cache_key. + """ + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf: + zf.extractall(temp_path) + + return import_experiments(temp_path) + + +def run_from_stored_experiment_item(stored: StoredExperimentItem) -> Optional[Run]: + """Convert StoredExperimentItem to Run with decision populated. + + Uses pre-loaded resolved_config instead of loading from (possibly deleted) path. + """ + item = stored.item + if not item.item.output: + return None + + probe_id = get_probe_id(item.item) + + decider_params = DeciderParams( + scenario_input=item.item.input, + alignment_target=item.config.alignment_target, + resolved_config=stored.resolved_config, + ) + + output = item.item.output + decision = RunDecision( + adm_result=ADMResult( + decision=Decision( + unstructured=output.action.unstructured, + justification=output.action.justification or "", + ), + choice_info=item.item.choice_info or ChoiceInfo(), + ), + choice_index=output.choice, + ) + + decider_name = item.experiment_path.parent.name + + return Run( + id=str(uuid.uuid4()), + probe_id=probe_id, + decider_name=decider_name, + llm_backbone_name=item.config.adm.llm_backbone or "N/A", + system_prompt="", + decider_params=decider_params, + decision=decision, + ) diff --git a/align_app/app/prompt_logic.py b/align_app/app/prompt_logic.py deleted file mode 100644 index f76ec2b..0000000 --- a/align_app/app/prompt_logic.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Pure business logic and validation functions for prompt handling.""" - -from typing import Dict, List, Any, cast -import copy -from omegaconf import OmegaConf -from ..adm.types import Attribute, DeciderParams, Prompt, attributes_to_alignment_target -from ..adm.probe import Probe -from ..adm.config import get_decider_config - - -def create_default_choice(index: int, text: str) -> Dict[str, Any]: - """Create a new choice with required fields for the alignment system.""" - return { - "action_id": f"action-{index}", - "unstructured": text, - "action_type": "APPLY_TREATMENT", - "intent_action": True, - "parameters": {}, - "justification": None, - } - - -def compute_possible_attributes( - all_attrs: Dict, used_attrs: set, descriptions: Dict -) -> List[Dict]: - """Compute available attributes not currently in use.""" - return [ - { - "value": key, - **details, - "description": descriptions.get(key, {}).get( - "description", f"No description available for {key}" - ), - } - for key, details in all_attrs.items() - if key not in 
used_attrs - ] - - -def filter_valid_attributes( - attributes: List[Dict], valid_attributes: Dict -) -> List[Dict]: - """Filter attributes to only include valid ones for the dataset.""" - return [ - attr - for attr in attributes - if attr["value"] in valid_attributes - and attr.get("possible_scores") - == valid_attributes[attr["value"]].get("possible_scores") - ] - - -def select_initial_decider(deciders: List[Dict], current: str = "") -> str: - """Select the initial decider based on current selection and available options.""" - if not deciders: - return "" - - valid_values = [dm["value"] for dm in deciders] - if not current or current not in valid_values: - return deciders[0]["value"] - return current - - -def build_choices_from_edited( - edited_choices: List[str], original_choices: List[Dict] -) -> List[Dict]: - """Build new choices array with edited text.""" - new_choices = [] - for i, choice_text in enumerate(edited_choices): - if i < len(original_choices): - choice = copy.deepcopy(original_choices[i]) - choice["unstructured"] = choice_text - else: - choice = create_default_choice(i, choice_text) - new_choices.append(choice) - return new_choices - - -def get_max_alignment_attributes(decider_configs: Dict) -> int: - """Extract max alignment attributes from decider configs.""" - if not decider_configs: - return 0 - return decider_configs.get("max_alignment_attributes", 0) - - -def get_llm_backbones_from_config(decider_configs: Dict) -> List[str]: - """Extract LLM backbones from decider config.""" - if decider_configs and "llm_backbones" in decider_configs: - return decider_configs["llm_backbones"] - return ["N/A"] - - -def find_probe_by_base_and_scene( - probes: Dict[str, Probe], base_id: str, scene_id: str -) -> str: - """Find the full probe_id given base and scene IDs.""" - matches = [ - probe_id - for probe_id, probe in probes.items() - if probe.scenario_id == base_id and probe.scene_id == scene_id - ] - return matches[0] if matches else "" - - -def create_prompt_base( - probe: Probe, llm_backbone: str, decider: str, attributes: List[Attribute] -) -> dict: - """Build base prompt structure from components. - - Returns a dict with probe as Probe model (internal representation). 
- """ - return { - "decider_params": DeciderParams(llm_backbone=llm_backbone, decider=decider), - "alignment_target": attributes_to_alignment_target(attributes), - "probe": probe, - } - - -def build_prompt_context( - probe_id: str, - llm_backbone: str, - decider: str, - attributes: List[Dict], - system_prompt: str, - edited_text: str, - edited_choices: List[str], - decider_registry, - probe_registry, -) -> Dict: - mapped_attributes: List[Attribute] = [ - Attribute(type=a["value"], score=a["score"]) for a in attributes - ] - - probe = probe_registry.get_probe(probe_id) - - prompt_data = create_prompt_base(probe, llm_backbone, decider, mapped_attributes) - - resolved_config = decider_registry.get_decider_config( - probe.probe_id, - decider=decider, - llm_backbone=llm_backbone, - ) - - original_choices = cast(List[Dict], probe.choices or []) - edited_choices_list = build_choices_from_edited(edited_choices, original_choices) - - updated_full_state = copy.deepcopy(probe.full_state) or {} - updated_full_state["unstructured"] = edited_text - - updated_probe = probe.model_copy( - update={ - "display_state": edited_text, - } - ) - - updated_probe.item.input.full_state = updated_full_state - updated_probe.item.input.choices = edited_choices_list - - return { - **prompt_data, - "system_prompt": system_prompt, - "resolved_config": resolved_config, - "probe": updated_probe, - "all_deciders": decider_registry.get_all_deciders(), - "datasets": probe_registry.get_datasets(), - } - - -def get_alignment_descriptions_map(prompt: Prompt) -> dict: - """Get attribute descriptions for alignment targets from ADM config.""" - probe: Probe = prompt["probe"] - probe_id = probe.probe_id - decider = prompt["decider_params"]["decider"] - all_deciders = prompt["all_deciders"] - datasets = prompt["datasets"] - - config = get_decider_config(probe_id, all_deciders, datasets, decider) - if not config: - return {} - - config.pop("instance", None) - config.pop("step_definitions", None) - - attributes_resolved = OmegaConf.to_container( - OmegaConf.create({"adm": config}), - resolve=True, - ) - - if not isinstance(attributes_resolved, dict): - return {} - - adm_section = attributes_resolved.get("adm", {}) - if not isinstance(adm_section, dict): - return {} - - attribute_map = adm_section.get("attribute_definitions", {}) - - return attribute_map diff --git a/align_app/app/runs_core.py b/align_app/app/runs_core.py index 7468751..180c5d3 100644 --- a/align_app/app/runs_core.py +++ b/align_app/app/runs_core.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, replace from typing import Dict, Optional, List -from .run_models import Run, RunDecision +from ..adm.run_models import Run, RunDecision from ..adm.decider import get_decision @@ -22,6 +22,36 @@ def add_run(data: Runs, run: Run) -> Runs: return new_data +def add_runs_bulk(data: Runs, runs: List[Run]) -> Runs: + """Add multiple runs efficiently in a single operation.""" + new_runs = {**data.runs} + new_cache = {**data.decision_cache} + + for run in runs: + new_runs[run.id] = run + if run.decision: + cache_key = run.compute_cache_key() + new_cache[cache_key] = run.decision + + return Runs(runs=new_runs, decision_cache=new_cache) + + +def populate_cache_bulk(data: Runs, runs: List[Run]) -> Runs: + """Populate decision cache from runs without adding to runs dict. + + Use for pre-computed experiment results that should populate cache + but not appear in UI. 
+ """ + new_cache = {**data.decision_cache} + + for run in runs: + if run.decision: + cache_key = run.compute_cache_key() + new_cache[cache_key] = run.decision + + return replace(data, decision_cache=new_cache) + + def remove_run(data: Runs, run_id: str) -> Runs: runs = {rid: run for rid, run in data.runs.items() if rid != run_id} return replace(data, runs=runs) @@ -84,5 +114,9 @@ def init_runs() -> Runs: return Runs.empty() -def clear_runs(_: Runs) -> Runs: +def clear_runs(data: Runs) -> Runs: + return replace(data, runs={}) + + +def clear_all(data: Runs) -> Runs: return Runs.empty() diff --git a/align_app/app/runs_edit_logic.py b/align_app/app/runs_edit_logic.py index aa8e47a..c1317c1 100644 --- a/align_app/app/runs_edit_logic.py +++ b/align_app/app/runs_edit_logic.py @@ -1,19 +1,39 @@ """Update runs with new scenes and scenarios.""" from typing import Optional, Dict, Any -from .run_models import Run +from ..adm.run_models import Run from ..adm.probe import Probe import copy -from .prompt_logic import ( - create_default_choice, - find_probe_by_base_and_scene, +from .runs_presentation import ( + get_scenes_for_base_scenario, get_llm_backbones_from_config, get_max_alignment_attributes, ) -from .runs_presentation import get_scenes_for_base_scenario from align_utils.models import AlignmentTarget, KDMAValue +def create_default_choice(index: int, text: str) -> Dict[str, Any]: + return { + "action_id": f"action-{index}", + "unstructured": text, + "action_type": "APPLY_TREATMENT", + "intent_action": True, + "parameters": {}, + "justification": None, + } + + +def find_probe_by_scenario_and_scene( + probes: Dict[str, Probe], scenario_id: str, scene_id: str +) -> str: + matches = [ + probe_id + for probe_id, probe in probes.items() + if probe.scenario_id == scenario_id and probe.scene_id == scene_id + ] + return matches[0] if matches else "" + + def get_first_scene_for_scenario(probes: Dict[str, Probe], scenario_id: str) -> str: scene_items = get_scenes_for_base_scenario(probes, scenario_id) return scene_items[0]["value"] if scene_items else "" @@ -49,7 +69,7 @@ def prepare_scene_update( base_scenario_id = current_probe.scenario_id probes = probe_registry.get_probes() - new_probe_id = find_probe_by_base_and_scene(probes, base_scenario_id, scene_id) + new_probe_id = find_probe_by_scenario_and_scene(probes, base_scenario_id, scene_id) new_probe = probe_registry.get_probe(new_probe_id) if not new_probe: @@ -73,7 +93,7 @@ def prepare_scenario_update( if not first_scene_id: return None - new_probe_id = find_probe_by_base_and_scene(probes, scenario_id, first_scene_id) + new_probe_id = find_probe_by_scenario_and_scene(probes, scenario_id, first_scene_id) new_probe = probe_registry.get_probe(new_probe_id) if not new_probe: diff --git a/align_app/app/runs_presentation.py b/align_app/app/runs_presentation.py index 09564ba..c78e667 100644 --- a/align_app/app/runs_presentation.py +++ b/align_app/app/runs_presentation.py @@ -1,17 +1,109 @@ """Transform domain models to UI state dictionaries and export formats.""" from typing import Dict, Any, List -from .run_models import Run, RunDecision, hash_run_params +from ..adm.run_models import Run, RunDecision, hash_run_params from .ui import prep_decision_for_state -from .prompt_logic import ( - get_llm_backbones_from_config, - get_max_alignment_attributes, - compute_possible_attributes, -) -from ..adm.probe import Probe +from ..adm.probe import Probe, get_probe_id +from ..adm.decider.types import DeciderParams from ..utils.utils import readable +from 
align_utils.models import ExperimentItem import json import copy +import yaml + + +def compute_experiment_item_cache_key( + item: ExperimentItem, + resolved_config: Dict[str, Any], +) -> str: + """Compute cache_key for an experiment item (same as Run.compute_cache_key). + + Takes resolved_config as param since it must be loaded while paths are valid. + """ + probe_id = get_probe_id(item.item) + decider_name = item.experiment_path.parent.name + llm_backbone = item.config.adm.llm_backbone or "N/A" + + decider_params = DeciderParams( + scenario_input=item.item.input, + alignment_target=item.config.alignment_target, + resolved_config=resolved_config, + ) + + return hash_run_params(probe_id, decider_name, llm_backbone, decider_params) + + +def experiment_item_to_table_row( + item: ExperimentItem, cache_key: str +) -> Dict[str, Any]: + """Convert ExperimentItem to table row format.""" + scene_id = "" + display_state = "" + if item.item.input.full_state: + scene_id = item.item.input.full_state.get("meta_info", {}).get("scene_id", "") + display_state = item.item.input.full_state.get("unstructured", "") or "" + + decision_text = "" + if item.item.output: + choice_letter = chr(item.item.output.choice + ord("A")) + decision_text = f"{choice_letter}. {item.item.output.action.unstructured}" + + kdma_values = item.config.alignment_target.kdma_values + alignment_summary = ( + ", ".join(f"{kv.kdma} {kv.value}" for kv in kdma_values) + if kdma_values + else "None" + ) + + choices = item.item.input.choices or [] + choice_texts = " ".join(c.get("unstructured", "") for c in choices) + searchable_text = f"{display_state} {choice_texts}" + + return { + "id": cache_key, + "scenario_id": item.item.input.scenario_id, + "scene_id": scene_id, + "probe_text": display_state, + "decider_name": item.config.adm.name, + "llm_backbone_name": item.config.adm.llm_backbone or "N/A", + "alignment_summary": alignment_summary, + "decision_text": decision_text, + "searchable_text": searchable_text, + } + + +def get_max_alignment_attributes(decider_configs: Dict) -> int: + if not decider_configs: + return 0 + return decider_configs.get("max_alignment_attributes", 0) + + +def get_llm_backbones_from_config(decider_configs: Dict) -> List[str]: + if decider_configs and "llm_backbones" in decider_configs: + return decider_configs["llm_backbones"] + return ["N/A"] + + +def compute_possible_attributes( + all_attrs: Dict, used_attrs: set, descriptions: Dict +) -> List[Dict]: + return [ + { + "value": key, + **details, + "description": descriptions.get(key, {}).get( + "description", f"No description available for {key}" + ), + } + for key, details in all_attrs.items() + if key not in used_attrs + ] + + +def resolved_config_to_yaml(resolved_config: Dict[str, Any] | None) -> str: + if not resolved_config: + return "" + return yaml.dump(resolved_config, default_flow_style=False, sort_keys=False) def extract_base_scenarios(probes: Dict[str, Probe]) -> List[Dict]: @@ -62,30 +154,33 @@ def kdma_values_to_alignment_attributes( return result -def compute_possible_alignment_attributes_for_run( +def _get_attribute_descriptions( run: Run, probe_registry, decider_registry -) -> List[Dict[str, Any]]: - """Compute available alignment attributes not currently in use for a run.""" - all_attrs = probe_registry.get_attributes(run.probe_id) - used_kdmas = {kv.kdma for kv in run.decider_params.alignment_target.kdma_values} - - descriptions = {} +) -> Dict[str, Any]: + """Extract attribute_definitions from decider config.""" all_deciders = 
decider_registry.get_all_deciders() datasets = probe_registry.get_datasets() from ..adm.config import get_decider_config as get_config config = get_config(run.probe_id, all_deciders, datasets, run.decider_name) - if config: - from omegaconf import OmegaConf + if not config: + return {} + + from omegaconf import OmegaConf + + config.pop("instance", None) + config.pop("step_definitions", None) + resolved = OmegaConf.to_container(OmegaConf.create({"adm": config}), resolve=True) + if isinstance(resolved, dict): + return resolved.get("adm", {}).get("attribute_definitions", {}) + return {} - config.pop("instance", None) - config.pop("step_definitions", None) - resolved = OmegaConf.to_container( - OmegaConf.create({"adm": config}), resolve=True - ) - if isinstance(resolved, dict): - descriptions = resolved.get("adm", {}).get("attribute_definitions", {}) +def _compute_possible_alignment_attributes( + run: Run, all_attrs: Dict, descriptions: Dict +) -> List[Dict[str, Any]]: + """Compute available alignment attributes not currently in use.""" + used_kdmas = {kv.kdma for kv in run.decider_params.alignment_target.kdma_values} possible = compute_possible_attributes(all_attrs, used_kdmas, descriptions) return [ { @@ -166,28 +261,15 @@ def run_to_state_dict( if probe_registry and decider_registry: all_attrs = probe_registry.get_attributes(run.probe_id) - datasets = probe_registry.get_datasets() - all_deciders = decider_registry.get_all_deciders() - - from ..adm.config import get_decider_config as get_config - from omegaconf import OmegaConf - - descriptions = {} - config = get_config(run.probe_id, all_deciders, datasets, run.decider_name) - if config: - config.pop("instance", None) - config.pop("step_definitions", None) - resolved = OmegaConf.to_container( - OmegaConf.create({"adm": config}), resolve=True - ) - if isinstance(resolved, dict): - descriptions = resolved.get("adm", {}).get("attribute_definitions", {}) + descriptions = _get_attribute_descriptions( + run, probe_registry, decider_registry + ) alignment_attributes = kdma_values_to_alignment_attributes( run.decider_params.alignment_target.kdma_values, all_attrs, descriptions ) - possible_alignment_attributes = compute_possible_alignment_attributes_for_run( - run, probe_registry, decider_registry + possible_alignment_attributes = _compute_possible_alignment_attributes( + run, all_attrs, descriptions ) cache_key = hash_run_params( @@ -216,6 +298,9 @@ def run_to_state_dict( }, "system_prompt": system_prompt, "resolved_config": run.decider_params.resolved_config, + "resolved_config_yaml": resolved_config_to_yaml( + run.decider_params.resolved_config + ), "decider": {"name": run.decider_name}, "llm_backbone": run.llm_backbone_name, }, @@ -225,6 +310,36 @@ def run_to_state_dict( return result +def run_to_table_row(run_dict: Dict[str, Any]) -> Dict[str, Any]: + prompt = run_dict.get("prompt", {}) + probe = prompt.get("probe", {}) + decision = run_dict.get("decision") + + alignment_attrs = run_dict.get("alignment_attributes", []) + alignment_summary = ( + ", ".join(f"{a['title']} {a['score']}" for a in alignment_attrs) + if alignment_attrs + else "None" + ) + + display_state = probe.get("display_state", "") or "" + choices = probe.get("choices", []) or [] + choice_texts = " ".join(c.get("unstructured", "") for c in choices) + searchable_text = f"{display_state} {choice_texts}" + + return { + "id": run_dict["cache_key"], + "scenario_id": probe.get("scenario_id", ""), + "scene_id": probe.get("scene_id", ""), + "probe_text": display_state, + "decider_name": 
prompt.get("decider_params", {}).get("decider", ""), + "llm_backbone_name": prompt.get("decider_params", {}).get("llm_backbone", ""), + "alignment_summary": alignment_summary, + "decision_text": decision.get("unstructured", "") if decision else "", + "searchable_text": searchable_text, + } + + def export_runs_to_json(runs_dict: Dict[str, Dict[str, Any]]) -> str: exported_runs = [] diff --git a/align_app/app/runs_registry.py b/align_app/app/runs_registry.py index 146aa2d..9652fc6 100644 --- a/align_app/app/runs_registry.py +++ b/align_app/app/runs_registry.py @@ -1,73 +1,39 @@ """Service layer managing run state and coordinating domain operations.""" -from collections import namedtuple from typing import Optional, Dict, List, Any, Callable -from .run_models import Run +from ..adm.run_models import Run from . import runs_core from . import runs_edit_logic from ..utils.utils import get_id +from .import_experiments import StoredExperimentItem, run_from_stored_experiment_item -RunsRegistry = namedtuple( - "RunsRegistry", - [ - "add_run", - "execute_decision", - "execute_run_decision", - "create_and_execute_run", - "get_run", - "get_all_runs", - "clear_runs", - "update_run_scene", - "update_run_scenario", - "update_run_decider", - "update_run_llm_backbone", - "add_run_alignment_attribute", - "update_run_alignment_attribute_value", - "update_run_alignment_attribute_score", - "delete_run_alignment_attribute", - "update_run_probe_text", - "update_run_choice_text", - "add_run_choice", - "delete_run_choice", - ], -) - - -def create_runs_registry(probe_registry, decider_registry): - """Create runs service with methods for managing run lifecycle.""" - data = runs_core.init_runs() +class RunsRegistry: + def __init__(self, probe_registry, decider_registry): + self._probe_registry = probe_registry + self._decider_registry = decider_registry + self._runs = runs_core.init_runs() + self._experiment_items: Dict[str, StoredExperimentItem] = {} def _create_update_method( + self, prepare_fn: Callable[..., Optional[Run]], ) -> Callable[[str, Any], Optional[Run]]: - """Factory that generates registry update methods. - - Args: - prepare_fn: Orchestration helper that prepares updated run. 
- Signature: (run, value, *, probe_registry, decider_registry) -> Optional[Run] - - Returns: - Update method with signature: (run_id, value) -> Optional[Run] - """ - def update_method(run_id: str, value: Any) -> Optional[Run]: - nonlocal data - - run = runs_core.get_run(data, run_id) + run = runs_core.get_run(self._runs, run_id) if not run: return None updated_run = prepare_fn( run, value, - probe_registry=probe_registry, - decider_registry=decider_registry, + probe_registry=self._probe_registry, + decider_registry=self._decider_registry, ) if not updated_run: return None - system_prompt = decider_registry.get_system_prompt( + system_prompt = self._decider_registry.get_system_prompt( decider=updated_run.decider_name, alignment_target=updated_run.decider_params.alignment_target, probe_id=updated_run.probe_id, @@ -81,134 +47,159 @@ def update_method(run_id: str, value: Any) -> Optional[Run]: "system_prompt": system_prompt, } ) - new_run = runs_core.apply_cached_decision(data, new_run) + new_run = runs_core.apply_cached_decision(self._runs, new_run) if run.decision is None: - data = runs_core.remove_run(data, run_id) - data = runs_core.add_run(data, new_run) + self._runs = runs_core.remove_run(self._runs, run_id) + self._runs = runs_core.add_run(self._runs, new_run) return new_run return update_method - def add_run(run: Run) -> Run: - nonlocal data - data = runs_core.add_run(data, run) + def add_run(self, run: Run) -> Run: + self._runs = runs_core.add_run(self._runs, run) return run - async def execute_decision(run: Run, probe_choices: List[Dict]) -> Run: - nonlocal data + def add_runs_bulk(self, runs: List[Run]) -> None: + self._runs = runs_core.add_runs_bulk(self._runs, runs) + + def populate_cache_bulk(self, runs: List[Run]) -> None: + self._runs = runs_core.populate_cache_bulk(self._runs, runs) + + async def _execute_with_cache(self, run: Run, probe_choices: List[Dict]) -> Run: cache_key = run.compute_cache_key() - cached = runs_core.get_cached_decision(data, cache_key) + cached = runs_core.get_cached_decision(self._runs, cache_key) if cached: updated_run = run.model_copy(update={"decision": cached}) - data = runs_core.add_run(data, updated_run) + self._runs = runs_core.add_run(self._runs, updated_run) return updated_run decision = await runs_core.fetch_decision(run, probe_choices) updated_run = run.model_copy(update={"decision": decision}) - data = runs_core.add_run(data, updated_run) - data = runs_core.add_cached_decision(data, cache_key, decision) + self._runs = runs_core.add_run(self._runs, updated_run) + self._runs = runs_core.add_cached_decision(self._runs, cache_key, decision) return updated_run - async def execute_run_decision(run_id: str) -> Optional[Run]: - nonlocal data + async def execute_decision(self, run: Run, probe_choices: List[Dict]) -> Run: + return await self._execute_with_cache(run, probe_choices) - run = runs_core.get_run(data, run_id) + async def execute_run_decision(self, run_id: str) -> Optional[Run]: + run = runs_core.get_run(self._runs, run_id) if not run: return None - probe = probe_registry.get_probe(run.probe_id) + probe = self._probe_registry.get_probe(run.probe_id) if not probe: return None - probe_choices = probe.choices or [] - cache_key = run.compute_cache_key() - - cached = runs_core.get_cached_decision(data, cache_key) - if cached: - updated_run = run.model_copy(update={"decision": cached}) - data = runs_core.add_run(data, updated_run) - return updated_run - - decision = await runs_core.fetch_decision(run, probe_choices) - - updated_run = 
run.model_copy(update={"decision": decision}) - data = runs_core.add_run(data, updated_run) - data = runs_core.add_cached_decision(data, cache_key, decision) - return updated_run + return await self._execute_with_cache(run, probe.choices or []) - async def create_and_execute_run(run: Run, probe_choices: List[Dict]): - nonlocal data - cache_key = run.compute_cache_key() + def get_run(self, run_id: str) -> Optional[Run]: + run = runs_core.get_run(self._runs, run_id) + if run: + run = runs_core.apply_cached_decision(self._runs, run) + return run - cached = runs_core.get_cached_decision(data, cache_key) - if cached: - updated_run = run.model_copy(update={"decision": cached}) - data = runs_core.add_run(data, updated_run) - return data, updated_run + def get_all_runs(self) -> Dict[str, Run]: + return dict(runs_core.get_all_runs_with_cached_decisions(self._runs)) + + def clear_runs(self): + self._runs = runs_core.clear_runs(self._runs) + return self._runs + + def clear_all(self): + self._runs = runs_core.init_runs() + self._experiment_items = {} + + def update_run_scene(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_scene_update)( + run_id, value + ) + + def update_run_scenario(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_scenario_update)( + run_id, value + ) + + def update_run_decider(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_decider_update)( + run_id, value + ) + + def update_run_llm_backbone(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_llm_update)( + run_id, value + ) + + def add_run_alignment_attribute(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method( + runs_edit_logic.prepare_add_alignment_attribute + )(run_id, value) + + def update_run_alignment_attribute_value( + self, run_id: str, value: Any + ) -> Optional[Run]: + return self._create_update_method( + runs_edit_logic.prepare_update_alignment_attribute_value + )(run_id, value) + + def update_run_alignment_attribute_score( + self, run_id: str, value: Any + ) -> Optional[Run]: + return self._create_update_method( + runs_edit_logic.prepare_update_alignment_attribute_score + )(run_id, value) + + def delete_run_alignment_attribute(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method( + runs_edit_logic.prepare_delete_alignment_attribute + )(run_id, value) + + def update_run_probe_text(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_update_probe_text)( + run_id, value + ) + + def update_run_choice_text(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_update_choice_text)( + run_id, value + ) + + def add_run_choice(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_add_run_choice)( + run_id, value + ) + + def delete_run_choice(self, run_id: str, value: Any) -> Optional[Run]: + return self._create_update_method(runs_edit_logic.prepare_delete_run_choice)( + run_id, value + ) + + def add_experiment_items(self, items: Dict[str, StoredExperimentItem]): + """Add experiment items (keyed by cache_key).""" + self._experiment_items = {**self._experiment_items, **items} + + def get_experiment_item(self, cache_key: str) -> Optional[StoredExperimentItem]: + return 
self._experiment_items.get(cache_key) + + def get_all_experiment_items(self) -> Dict[str, StoredExperimentItem]: + return self._experiment_items + + def materialize_experiment_item(self, cache_key: str) -> Optional[Run]: + """Convert experiment item to Run on demand. Populates decision cache.""" + stored = self._experiment_items.get(cache_key) + if not stored: + return None + run = run_from_stored_experiment_item(stored) + if run: + self._runs = runs_core.add_run(self._runs, run) + return run - decision = await runs_core.fetch_decision(run, probe_choices) - updated_run = run.model_copy(update={"decision": decision}) - data = runs_core.add_run(data, updated_run) - data = runs_core.add_cached_decision(data, cache_key, decision) - return data, updated_run - - def get_run(run_id: str) -> Optional[Run]: - return runs_core.get_run(data, run_id) - - def get_all_runs() -> Dict[str, Run]: - return dict(runs_core.get_all_runs_with_cached_decisions(data)) - - def clear_runs(): - nonlocal data - data = runs_core.clear_runs(data) - return data - - update_run_scene = _create_update_method(runs_edit_logic.prepare_scene_update) - update_run_scenario = _create_update_method(runs_edit_logic.prepare_scenario_update) - update_run_decider = _create_update_method(runs_edit_logic.prepare_decider_update) - update_run_llm_backbone = _create_update_method(runs_edit_logic.prepare_llm_update) - add_run_alignment_attribute = _create_update_method( - runs_edit_logic.prepare_add_alignment_attribute - ) - update_run_alignment_attribute_value = _create_update_method( - runs_edit_logic.prepare_update_alignment_attribute_value - ) - update_run_alignment_attribute_score = _create_update_method( - runs_edit_logic.prepare_update_alignment_attribute_score - ) - delete_run_alignment_attribute = _create_update_method( - runs_edit_logic.prepare_delete_alignment_attribute - ) - update_run_probe_text = _create_update_method( - runs_edit_logic.prepare_update_probe_text - ) - update_run_choice_text = _create_update_method( - runs_edit_logic.prepare_update_choice_text - ) - add_run_choice = _create_update_method(runs_edit_logic.prepare_add_run_choice) - delete_run_choice = _create_update_method(runs_edit_logic.prepare_delete_run_choice) - - return RunsRegistry( - add_run=add_run, - execute_decision=execute_decision, - execute_run_decision=execute_run_decision, - create_and_execute_run=create_and_execute_run, - get_run=get_run, - get_all_runs=get_all_runs, - clear_runs=clear_runs, - update_run_scene=update_run_scene, - update_run_scenario=update_run_scenario, - update_run_decider=update_run_decider, - update_run_llm_backbone=update_run_llm_backbone, - add_run_alignment_attribute=add_run_alignment_attribute, - update_run_alignment_attribute_value=update_run_alignment_attribute_value, - update_run_alignment_attribute_score=update_run_alignment_attribute_score, - delete_run_alignment_attribute=delete_run_alignment_attribute, - update_run_probe_text=update_run_probe_text, - update_run_choice_text=update_run_choice_text, - add_run_choice=add_run_choice, - delete_run_choice=delete_run_choice, - ) + def get_run_by_cache_key(self, cache_key: str) -> Optional[Run]: + """Find run by cache_key.""" + for run in self._runs.runs.values(): + if run.compute_cache_key() == cache_key: + return runs_core.apply_cached_decision(self._runs, run) + return None diff --git a/align_app/app/runs_state_adapter.py b/align_app/app/runs_state_adapter.py index 25c4457..8d2e8a7 100644 --- a/align_app/app/runs_state_adapter.py +++ b/align_app/app/runs_state_adapter.py 
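Stepping back from the patch for a moment: the `_execute_with_cache` helper above computes a cache key from the run's parameters, reuses any cached decision, and otherwise fetches one and stores it under that key. The snippet below is a minimal, self-contained sketch of that check-then-store pattern only; `toy_cache_key` and `toy_fetch_decision` are illustrative stand-ins, not the project's `hash_run_params` or its decider call.

```python
import hashlib
import json
from typing import Any, Dict

def toy_cache_key(probe_id: str, decider: str, llm: str, params: Dict[str, Any]) -> str:
    # Stable digest over the parameters that determine a decision;
    # sort_keys normalizes key order so equivalent runs hash identically.
    payload = json.dumps(
        {"probe_id": probe_id, "decider": decider, "llm": llm, "params": params},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

decision_cache: Dict[str, str] = {}

def toy_fetch_decision(probe_id: str) -> str:
    # Placeholder for the expensive ADM call.
    return f"decision for {probe_id}"

def execute_with_cache(probe_id: str, decider: str, llm: str, params: Dict[str, Any]) -> str:
    key = toy_cache_key(probe_id, decider, llm, params)
    if key in decision_cache:          # cache hit: reuse, no recompute
        return decision_cache[key]
    decision = toy_fetch_decision(probe_id)
    decision_cache[key] = decision     # store for identical future runs
    return decision

# Two identical parameter sets share one cached decision.
first = execute_with_cache("probe-1", "pipeline_baseline", "N/A", {"kdma": 0.5})
second = execute_with_cache("probe-1", "pipeline_baseline", "N/A", {"kdma": 0.5})
assert first == second and len(decision_cache) == 1
```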
@@ -1,12 +1,16 @@ from typing import Dict, Optional from trame.app import asynchronous -from trame.decorators import TrameApp, controller, change -from .run_models import Run +from trame.app.file_upload import ClientFile +from trame.decorators import TrameApp, controller, change, trigger +from ..adm.run_models import Run from .runs_registry import RunsRegistry +from .runs_table_filter import RunsTableFilter from ..adm.decider.types import DeciderParams from ..utils.utils import get_id from .runs_presentation import extract_base_scenarios from . import runs_presentation +from .export_experiments import export_runs_to_zip +from .import_experiments import import_experiments_from_zip from align_utils.models import AlignmentTarget @@ -20,6 +24,20 @@ def __init__( self.probe_registry = probe_registry self.decider_registry = decider_registry self.server.state.pending_cache_keys = [] + self.server.state.runs_table_modal_open = False + self.server.state.runs_table_selected = [] + self.server.state.runs_table_search = "" + self.server.state.runs_table_headers = [ + {"title": "Scenario", "key": "scenario_id"}, + {"title": "Scene", "key": "scene_id"}, + {"title": "Situation", "key": "probe_text", "sortable": False}, + {"title": "Decider", "key": "decider_name"}, + {"title": "LLM", "key": "llm_backbone_name"}, + {"title": "Alignment", "key": "alignment_summary"}, + {"title": "Decision", "key": "decision_text"}, + ] + self.server.state.import_experiment_file = None + self.table_filter = RunsTableFilter(server) self._sync_from_runs_data(runs_registry.get_all_runs()) @property @@ -27,12 +45,30 @@ def state(self): return self.server.state def _sync_from_runs_data(self, runs_dict: Dict[str, Run]): - self.state.runs = { - run_id: runs_presentation.run_to_state_dict( + new_runs = {} + for run_id, run in runs_dict.items(): + new_run = runs_presentation.run_to_state_dict( run, self.probe_registry, self.decider_registry ) - for run_id, run in runs_dict.items() - } + new_runs[run_id] = new_run + self.state.runs = new_runs + + run_table_rows = [ + runs_presentation.run_to_table_row(run_dict) + for run_dict in new_runs.values() + ] + + active_cache_keys = {run.compute_cache_key() for run in runs_dict.values()} + stored_items = self.runs_registry.get_all_experiment_items() + experiment_table_rows = [ + runs_presentation.experiment_item_to_table_row( + stored.item, stored.cache_key + ) + for cache_key, stored in stored_items.items() + if cache_key not in active_cache_keys + ] + + self.table_filter.set_all_rows(run_table_rows + experiment_table_rows) probes = self.probe_registry.get_probes() self.state.base_scenarios = extract_base_scenarios(probes) @@ -48,6 +84,14 @@ def reset_state(self): self._sync_from_runs_data({}) self.create_default_run() + @controller.set("clear_all_runs") + def clear_all_runs(self): + self.runs_registry.clear_all() + self._sync_from_runs_data({}) + self.state.runs_to_compare = [] + self.state.runs_table_selected = [] + self.create_default_run() + def create_default_run(self): probes = self.probe_registry.get_probes() if not probes: @@ -233,6 +277,12 @@ def update_run_choice_text(self, run_id: str, index: int, text: str): choices[index]["unstructured"] = text self.state.dirty("runs") + @controller.set("update_run_config_yaml") + def update_run_config_yaml(self, run_id: str, yaml_text: str): + if run_id in self.state.runs: + self.state.runs[run_id]["prompt"]["resolved_config_yaml"] = yaml_text + self.state.dirty("runs") + @controller.set("add_run_choice") def add_run_choice(self, run_id: str): 
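As an aside on the `_sync_from_runs_data` hunk above: active runs and stored experiment items are merged into a single table, and an experiment row is suppressed whenever an active run already covers the same cache key. Below is a minimal, self-contained sketch of that row-merging rule under assumed data; the row dictionaries and the `cache-*` keys are made up for illustration and are not the project's code.

```python
from typing import Dict, List

# Hypothetical rows keyed by cache_key; only "id" matters for deduplication.
active_run_rows: List[Dict[str, str]] = [
    {"id": "cache-aaa", "decider_name": "pipeline_baseline"},
]
stored_experiment_rows: Dict[str, Dict[str, str]] = {
    "cache-aaa": {"id": "cache-aaa", "decider_name": "pipeline_baseline"},
    "cache-bbb": {"id": "cache-bbb", "decider_name": "pipeline_baseline_greedy_w_cache"},
}

# Cache keys already represented by active runs win; matching experiment
# rows are dropped so the table never shows the same run twice.
active_keys = {row["id"] for row in active_run_rows}
experiment_rows = [
    row for key, row in stored_experiment_rows.items() if key not in active_keys
]

table_rows = active_run_rows + experiment_rows
assert [row["id"] for row in table_rows] == ["cache-aaa", "cache-bbb"]
```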
new_run = self.runs_registry.add_run_choice(run_id, None) @@ -249,9 +299,17 @@ def check_probe_edited(self, run_id: str): return new_probe_id = self._create_edited_probe_for_run(run_id) - new_probe = self.probe_registry.get_probe(new_probe_id) + if not new_probe_id: + return + run = self.runs_registry.get_run(run_id) + if not run: + return + + if new_probe_id == run.probe_id: + return + new_probe = self.probe_registry.get_probe(new_probe_id) updated_params = run.decider_params.model_copy( update={"scenario_input": new_probe.item.input} ) @@ -271,6 +329,13 @@ def check_probe_edited(self, run_id: str): ] self._sync_from_runs_data(self.runs_registry.get_all_runs()) + @controller.set("check_config_edited") + def check_config_edited(self, run_id: str): + if not self._is_config_edited(run_id): + return + + self._create_run_with_edited_config(run_id) + def _is_probe_edited(self, run_id: str) -> bool: """Check if UI state differs from original probe.""" run = self.runs_registry.get_run(run_id) @@ -304,9 +369,11 @@ def _is_probe_edited(self, run_id: str) -> bool: return False - def _create_edited_probe_for_run(self, run_id: str) -> str: + def _create_edited_probe_for_run(self, run_id: str) -> Optional[str]: """Create new probe from UI state edited content. Returns new probe_id.""" run = self.runs_registry.get_run(run_id) + if not run: + return None ui_run = self.state.runs[run_id] edited_text = ui_run["prompt"]["probe"].get("display_state", "") edited_choices = list(ui_run["prompt"]["probe"].get("choices", [])) @@ -316,6 +383,73 @@ def _create_edited_probe_for_run(self, run_id: str) -> str: ) return new_probe.probe_id + def _is_config_edited(self, run_id: str) -> bool: + """Check if UI config YAML differs from original resolved_config.""" + run = self.runs_registry.get_run(run_id) + if not run or run_id not in self.state.runs: + return False + + ui_yaml = self.state.runs[run_id]["prompt"].get("resolved_config_yaml", "") + original_yaml = runs_presentation.resolved_config_to_yaml( + run.decider_params.resolved_config + ) + return ui_yaml != original_yaml + + def _create_run_with_edited_config(self, run_id: str) -> Optional[str]: + """Create new run with edited config. 
Returns new run_id, or None if no change.""" + import yaml + from align_app.adm.decider_registry import _get_root_decider_name + + run = self.runs_registry.get_run(run_id) + if not run: + return None + ui_yaml = self.state.runs[run_id]["prompt"]["resolved_config_yaml"] + new_config = yaml.safe_load(ui_yaml) + + decider_options = self.decider_registry.get_decider_options( + run.probe_id, run.decider_name + ) + llm_backbones = ( + decider_options.get("llm_backbones", []) if decider_options else [] + ) + + root_decider_name = _get_root_decider_name(run.decider_name) + root_config = self.decider_registry.get_decider_config( + probe_id=run.probe_id, + decider=root_decider_name, + llm_backbone=run.llm_backbone_name, + ) + + if root_config == new_config: + new_decider_name = root_decider_name + else: + new_decider_name = self.decider_registry.add_edited_decider( + run.decider_name, new_config, llm_backbones + ) + + if new_decider_name == run.decider_name: + return None + + updated_params = run.decider_params.model_copy( + update={"resolved_config": new_config} + ) + new_run_id = get_id() + new_run = run.model_copy( + update={ + "id": new_run_id, + "decider_name": new_decider_name, + "decider_params": updated_params, + "decision": None, + } + ) + self.runs_registry.add_run(new_run) + + self.state.runs_to_compare = [ + new_run_id if rid == run_id else rid for rid in self.state.runs_to_compare + ] + self._sync_from_runs_data(self.runs_registry.get_all_runs()) + return new_run_id + def _add_pending_cache_key(self, cache_key: str): if cache_key and cache_key not in self.state.pending_cache_keys: self.state.pending_cache_keys = [*self.state.pending_cache_keys, cache_key] @@ -329,7 +463,11 @@ def _remove_pending_cache_key(self, cache_key: str): async def _execute_run_decision(self, run_id: str): if self._is_probe_edited(run_id): new_probe_id = self._create_edited_probe_for_run(run_id) + if not new_probe_id: + return run = self.runs_registry.get_run(run_id) + if not run: + return updated_run = run.model_copy(update={"probe_id": new_probe_id}) self.runs_registry.add_run(updated_run) self._sync_run_to_state(updated_run) @@ -346,7 +484,7 @@ async def _execute_run_decision(self, run_id: str): await self.server.network_completion - updated_run = await self.runs_registry.execute_run_decision(run_id) + await self.runs_registry.execute_run_decision(run_id) with self.state: all_runs = self.runs_registry.get_all_runs() @@ -360,8 +498,103 @@ def execute_run_decision(self, run_id: str): def export_runs_to_json(self) -> str: return runs_presentation.export_runs_to_json(self.state.runs) + @trigger("export_runs_zip") + def trigger_export_runs_zip(self) -> bytes: + return export_runs_to_zip(self.state.runs) + + @trigger("export_selected_runs_zip") + def trigger_export_selected_runs_zip(self) -> bytes: + selected = self.state.runs_table_selected + if not selected: + return b"" + + selected_runs = {} + for item in selected: + cache_key = item["id"] if isinstance(item, dict) else item + run = self.runs_registry.get_run_by_cache_key(cache_key) + if not run: + run = self.runs_registry.materialize_experiment_item(cache_key) + if run: + run_dict = runs_presentation.run_to_state_dict( + run, self.probe_registry, self.decider_registry + ) + selected_runs[run.id] = run_dict + + return export_runs_to_zip(selected_runs) + + @controller.set("update_runs_table_selected") + def update_runs_table_selected(self, selected): + self.state.runs_table_selected = selected if selected else [] + + @controller.set("open_runs_table_modal") + def 
open_runs_table_modal(self): + self.state.runs_table_modal_open = True + + @controller.set("close_runs_table_modal") + def close_runs_table_modal(self): + self.state.runs_table_modal_open = False + self.state.runs_table_selected = [] + + @controller.set("add_selected_runs_to_compare") + def add_selected_runs_to_compare(self): + selected = self.state.runs_table_selected + if not selected: + return + + existing = list(self.state.runs_to_compare) + + for item in selected: + cache_key = item["id"] if isinstance(item, dict) else item + + run = self.runs_registry.get_run_by_cache_key(cache_key) + + if not run: + run = self.runs_registry.materialize_experiment_item(cache_key) + + if run and run.id not in existing: + existing.append(run.id) + + self.state.runs_to_compare = existing + self.state.runs_table_modal_open = False + self.state.runs_table_selected = [] + self._sync_from_runs_data(self.runs_registry.get_all_runs()) + + @controller.set("on_table_row_click") + def on_table_row_click(self, _event, item): + cache_key = item.get("id") if isinstance(item, dict) else item + if not cache_key: + return + + run = self.runs_registry.get_run_by_cache_key(cache_key) + + if not run: + run = self.runs_registry.materialize_experiment_item(cache_key) + + if run and run.id not in self.state.runs_to_compare: + self.state.runs_to_compare = [*self.state.runs_to_compare, run.id] + self._sync_from_runs_data(self.runs_registry.get_all_runs()) + @change("runs") def update_runs_json(self, **_): json_data = self.export_runs_to_json() self.state.runs_json = json_data self.state.flush() + + @change("import_experiment_file") + def on_import_experiment_file(self, import_experiment_file, **_): + if import_experiment_file is None: + return + + file = ClientFile(import_experiment_file) + if not file.content: + return + + result = import_experiments_from_zip(file.content) + + self.probe_registry.add_probes(result.probes) + self.decider_registry.add_deciders(result.deciders) + self.runs_registry.add_experiment_items(result.items) + + self._sync_from_runs_data(self.runs_registry.get_all_runs()) + + self.state.import_experiment_file = None diff --git a/align_app/app/runs_table_filter.py b/align_app/app/runs_table_filter.py new file mode 100644 index 0000000..b9aff03 --- /dev/null +++ b/align_app/app/runs_table_filter.py @@ -0,0 +1,114 @@ +import re +from typing import List, Dict, Any, Tuple +from trame.decorators import TrameApp, change + + +def natural_sort_key(s: str) -> list: + return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", s)] + + +FILTER_COLUMNS = [ + ("runs_table_filter_scenario", "scenario_id"), + ("runs_table_filter_scene", "scene_id"), + ("runs_table_filter_decider", "decider_name"), + ("runs_table_filter_llm", "llm_backbone_name"), + ("runs_table_filter_alignment", "alignment_summary"), + ("runs_table_filter_decision", "decision_text"), +] + + +def compute_filter_options( + rows: List[Dict[str, Any]], +) -> Dict[str, List[str]]: + return { + "runs_table_scenario_options": sorted( + set(r["scenario_id"] for r in rows if r.get("scenario_id")), + key=natural_sort_key, + ), + "runs_table_scene_options": sorted( + set(r["scene_id"] for r in rows if r.get("scene_id")), + key=natural_sort_key, + ), + "runs_table_decider_options": sorted( + set(r["decider_name"] for r in rows if r.get("decider_name")), + key=natural_sort_key, + ), + "runs_table_llm_options": sorted( + set(r["llm_backbone_name"] for r in rows if r.get("llm_backbone_name")), + key=natural_sort_key, + ), + 
"runs_table_alignment_options": sorted( + set(r["alignment_summary"] for r in rows if r.get("alignment_summary")), + key=natural_sort_key, + ), + "runs_table_decision_options": sorted( + set(r["decision_text"] for r in rows if r.get("decision_text")), + key=natural_sort_key, + ), + } + + +def filter_rows( + rows: List[Dict[str, Any]], + filters: List[Tuple[List[str], str]], +) -> List[Dict[str, Any]]: + def row_matches(row: Dict[str, Any]) -> bool: + for filter_values, key in filters: + if filter_values and row.get(key) not in filter_values: + return False + return True + + return [r for r in rows if row_matches(r)] + + +@TrameApp() +class RunsTableFilter: + def __init__(self, server): + self.server = server + self._all_rows: List[Dict[str, Any]] = [] + + self.state.runs_table_filter_scenario = [] + self.state.runs_table_filter_scene = [] + self.state.runs_table_filter_decider = [] + self.state.runs_table_filter_llm = [] + self.state.runs_table_filter_alignment = [] + self.state.runs_table_filter_decision = [] + + self.state.runs_table_scenario_options = [] + self.state.runs_table_scene_options = [] + self.state.runs_table_decider_options = [] + self.state.runs_table_llm_options = [] + self.state.runs_table_alignment_options = [] + self.state.runs_table_decision_options = [] + + @property + def state(self): + return self.server.state + + def set_all_rows(self, rows: List[Dict[str, Any]]): + self._all_rows = rows + self._update_filter_options() + self._apply_filters() + + def _update_filter_options(self): + options = compute_filter_options(self._all_rows) + for key, value in options.items(): + setattr(self.state, key, value) + + @change( + "runs_table_filter_scenario", + "runs_table_filter_scene", + "runs_table_filter_decider", + "runs_table_filter_llm", + "runs_table_filter_alignment", + "runs_table_filter_decision", + ) + def _on_filter_change(self, **kwargs): + self._apply_filters() + + def _apply_filters(self): + filters = [ + (getattr(self.state, state_key) or [], col_key) + for state_key, col_key in FILTER_COLUMNS + ] + self.state.runs_table_items = filter_rows(self._all_rows, filters) diff --git a/align_app/app/search.py b/align_app/app/search.py index 07373bd..3974019 100644 --- a/align_app/app/search.py +++ b/align_app/app/search.py @@ -1,3 +1,4 @@ +from typing import Optional, Tuple from trame.decorators import TrameApp, controller from rapidfuzz import fuzz, process, utils from ..adm.probe import Probe @@ -8,10 +9,10 @@ class SearchController: """Controller for search functionality with dropdown menu.""" - def __init__(self, server, probe_registry): + def __init__(self, server, probe_registry, on_search_select=None): self.server = server self.probe_registry = probe_registry - self.runs_state_adapter = None + self._on_search_select = on_search_select self.server.state.search_query = "" self.server.state.search_results = [] self.server.state.search_menu_open = False @@ -20,9 +21,6 @@ def __init__(self, server, probe_registry): debounce(0.2, self.server.state)(self.update_search_results) ) - def set_runs_state_adapter(self, runs_state_adapter): - self.runs_state_adapter = runs_state_adapter - def _create_search_result(self, probe_id, probe: Probe): display_state = probe.display_state or "" display_text = display_state.split("\n")[0] if display_state else "" @@ -76,15 +74,19 @@ def update_search_results(self, search_query, **_): ] self.server.state.search_menu_open = True + def _get_search_selection(self, index: int) -> Optional[Tuple[str, str]]: + """Extract scenario_id and scene_id 
from search result at index.""" + if not (0 <= index < len(self.server.state.search_results)): + return None + result = self.server.state.search_results[index] + if result.get("id") is None: + return None + return (result.get("scenario_id"), result.get("scene_id")) + @controller.add("select_run_search_result") def select_run_search_result(self, run_id, index): - if 0 <= index < len(self.server.state.search_results): - result = self.server.state.search_results[index] - if result.get("id") is not None and self.runs_state_adapter: - new_run_id = self.runs_state_adapter.update_run_scenario( - run_id, result.get("scenario_id") - ) - self.runs_state_adapter.update_run_scene( - new_run_id, result.get("scene_id") - ) - self.server.state.run_search_expanded_id = None + selection = self._get_search_selection(index) + if selection and self._on_search_select: + scenario_id, scene_id = selection + self._on_search_select(run_id, scenario_id, scene_id) + self.server.state.run_search_expanded_id = None diff --git a/align_app/app/ui.py b/align_app/app/ui.py index bb527d9..dd998dd 100644 --- a/align_app/app/ui.py +++ b/align_app/app/ui.py @@ -1,11 +1,6 @@ -from typing import Any, Dict, cast -import copy from trame.ui.vuetify3 import SinglePageLayout from trame.widgets import vuetify3, html -from ..adm.types import Prompt, SerializedPrompt, SerializedAlignmentTarget -from ..adm.probe import Probe as ProbeModel -from ..utils.utils import noop, readable, readable_sentence, sentence_lines -from .prompt_logic import get_alignment_descriptions_map +from ..utils.utils import noop, readable, readable_sentence from .unordered_object import ( UnorderedObject, ValueWithProgressBar, @@ -15,39 +10,15 @@ ) -def serialize_prompt(prompt: Prompt) -> SerializedPrompt: - """Serialize a prompt for JSON/state storage, removing non-serializable fields. - - This is THE serialization boundary - converts Probe to dict for UI state. 
- Input: prompt["probe"] is Probe model - Output: prompt["probe"] is dict - """ - probe: ProbeModel = prompt["probe"] - alignment_target = cast( - SerializedAlignmentTarget, prompt["alignment_target"].model_dump() - ) - - system_prompt: str = prompt.get("system_prompt", "") # type: ignore[assignment] - result: SerializedPrompt = { - "probe": probe.to_dict(), - "alignment_target": alignment_target, - "decider_params": prompt["decider_params"], - "system_prompt": system_prompt, - } - - return copy.deepcopy(result) - - def reload(m=None): if m: m.__loader__.exec_module(m) -SENTENCE_KEYS = ["intent", "unstructured"] # Keys to apply sentence function to - RUN_COLUMN_MIN_WIDTH = "28rem" LABEL_COLUMN_WIDTH = "12rem" INDICATOR_SPACE = "3rem" +PENDING_SPINNER_CONDITION = "pending_cache_keys.includes(runs[id].cache_key)" TITLE_TRUNCATE_STYLE = ( "overflow: hidden; text-overflow: ellipsis; " f"white-space: nowrap; width: calc(100% - {INDICATOR_SPACE});" @@ -108,58 +79,6 @@ def __init__(self, alignment_info_expr, **kwargs): PlainObjectProperty("value") -def readable_probe(probe): - full_state = probe.full_state or {} - characters = full_state.get("characters", []) - readable_characters = [ - {**c, **{key: sentence_lines(c[key]) for key in SENTENCE_KEYS if key in c}} - for c in characters - ] - - return { - "probe_id": probe.probe_id, - "scene_id": probe.scene_id, - "scenario_id": probe.scenario_id, - "display_state": probe.display_state, - "full_state": {**full_state, "characters": readable_characters}, - "choices": probe.choices, - "state": probe.state, - } - - -def readable_attribute(kdma_value, descriptions): - return { - **kdma_value, - "description": descriptions.get(kdma_value.get("kdma"), {}).get( - "description", - f"No description for {kdma_value.get('kdma')}", - ), - "kdma": readable(kdma_value.get("kdma")), - "value": round(kdma_value.get("value"), 2), - } - - -def prep_for_state(prompt: Prompt): - descriptions = get_alignment_descriptions_map(prompt) - p = serialize_prompt(prompt) - result: Dict[str, Any] = { - **p, - "alignment_target": { - **p["alignment_target"], - "kdma_values": [ - readable_attribute(a, descriptions) - for a in p["alignment_target"]["kdma_values"] - ], - }, - "decider_params": { - **p["decider_params"], - "decider": readable(p["decider_params"]["decider"]), - }, - "probe": readable_probe(prompt["probe"]), - } - return result - - def make_keys_readable(obj, max_depth=2, current_depth=0): if current_depth >= max_depth or not isinstance(obj, dict): return obj @@ -388,16 +307,20 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def run_content(): - vuetify3.VSelect( - label="Decider", - items=("runs[id].decider_items",), - model_value=("runs[id].prompt.decider_params.decider",), - update_modelValue=( - self.server.controller.update_run_decider, - r"[id, $event]", - ), - hide_details="auto", - ) + with html.Div( + style=f"width: calc(100% - {INDICATOR_SPACE});", + raw_attrs=["@click.stop", "@mousedown.stop"], + ): + vuetify3.VSelect( + label="Decider", + items=("runs[id].decider_items",), + model_value=("runs[id].prompt.decider_params.decider",), + update_modelValue=( + self.server.controller.update_run_decider, + r"[id, $event]", + ), + hide_details="auto", + ) RowWithLabel( run_content=run_content, @@ -405,6 +328,35 @@ def run_content(): compare_expr="runs[id].prompt.decider_params.decider", ) + class Text(html.Template): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def run_content(): + ctrl = self.server.controller + with 
html.Div(style="align-self: flex-start; width: 100%;"): + vuetify3.VTextarea( + model_value=("runs[id].prompt.resolved_config_yaml",), + update_modelValue=( + ctrl.update_run_config_yaml, + r"[id, $event]", + ), + blur=(ctrl.check_config_edited, "[id]"), + auto_grow=True, + rows=1, + variant="outlined", + density="compact", + hide_details="auto", + classes="config-textarea", + style="font-family: monospace; font-size: 0.85em;", + ) + + RowWithLabel( + run_content=run_content, + label="Config", + compare_expr="runs[id].prompt.resolved_config_yaml", + ) + class Alignment: COMPARE_EXPR = "runs[id].alignment_attributes" @@ -541,15 +493,6 @@ def run_content(): ) -class ProbeLayout: - def __init__(self, probe): - html.Div("Situation", classes="text-h6") - html.P(f"{{{{{probe}.display_state}}}}", style="white-space: pre-wrap;") - html.Div("Choices", classes="text-h6 pt-4") - with html.Ol(classes="ml-8", type="A"): - html.Li("{{choice.unstructured}}", v_for=(f"choice in {probe}.choices")) - - class EditableProbeLayoutForRun: def __init__(self, server): ctrl = server.controller @@ -685,7 +628,7 @@ def render_run_decision(): ) with html.Template(v_else=True): vuetify3.VProgressCircular( - v_if=("pending_cache_keys.includes(runs[id].cache_key)",), + v_if=(PENDING_SPINNER_CONDITION,), indeterminate=True, size=20, ) @@ -730,7 +673,7 @@ def render_choice_info(): ) with html.Template(v_else=True): vuetify3.VProgressCircular( - v_if=("pending_cache_keys.includes(runs[id].cache_key)",), + v_if=(PENDING_SPINNER_CONDITION,), indeterminate=True, size=20, ) @@ -825,6 +768,177 @@ def no_runs(): RowWithLabel(run_content=run_content, label="Run Number", no_runs=no_runs) +def sortable_filter_header(key: str, title: str, filter_var: str, options_var: str): + """Create a sortable column header with filter dropdown.""" + with vuetify3.Template( + raw_attrs=[ + f'v-slot:header.{key}="{{ column, isSorted, getSortIcon, toggleSort }}"' + ], + ): + with html.Div( + classes="d-flex align-center cursor-pointer", + raw_attrs=["@click='toggleSort(column)'"], + ): + html.Span(title, classes="text-subtitle-2") + vuetify3.VIcon( + raw_attrs=[ + "v-if='isSorted(column)'", + ":icon='getSortIcon(column)'", + ], + size="small", + classes="ml-1", + ) + vuetify3.VSelect( + v_model=(filter_var,), + items=(options_var,), + clearable=True, + multiple=True, + density="compact", + hide_details=True, + raw_attrs=["@click.stop", "@mousedown.stop"], + ) + + +def cell_with_tooltip(key: str): + """Create cell template with native title tooltip.""" + with html.Template(raw_attrs=[f'v-slot:item.{key}="{{ item }}"']): + html.Span(f"{{{{ item.{key} }}}}", v_bind_title=f"item.{key}") + + +def filterable_column(key: str, title: str, filter_var: str, options_var: str): + """Create sortable column header with filter and cell tooltip.""" + sortable_filter_header(key, title, filter_var, options_var) + cell_with_tooltip(key) + + +class RunsTableModal(html.Div): + def __init__(self, **kwargs): + super().__init__(**kwargs) + with self: + with vuetify3.VDialog( + v_model=("runs_table_modal_open",), + fullscreen=True, + ): + with vuetify3.VCard(): + with vuetify3.VToolbar(density="compact"): + vuetify3.VToolbarTitle("Runs") + vuetify3.VSpacer() + with vuetify3.VBtn( + click=( + self.server.controller.add_selected_runs_to_compare, + ), + disabled=("runs_table_selected.length === 0",), + prepend_icon="mdi-plus", + classes="mr-4", + ): + html.Span("Add Selected to Comparison") + with vuetify3.VBtn( + click=( + "utils.download('align-app-experiments.zip', " + 
"trigger('export_selected_runs_zip'), 'application/zip')" + ), + disabled=("runs_table_selected.length === 0",), + prepend_icon="mdi-content-save", + classes="mr-4", + ): + html.Span("Save Selected") + vuetify3.VFileInput( + v_model=("import_experiment_file", None), + accept=".zip", + ref="tableImportFileInput", + style="display: none;", + ) + with vuetify3.VBtn( + click=( + "trame.refs.tableImportFileInput.$el" + ".querySelector('input').click()" + ), + prepend_icon="mdi-upload", + classes="mr-4", + ): + html.Span("Load Experiments") + with vuetify3.VBtn( + click=self.server.controller.clear_all_runs, + prepend_icon="mdi-delete-sweep", + classes="mr-4", + ): + html.Span("Clear All") + vuetify3.VTextField( + v_model=("runs_table_search",), + placeholder="Search", + prepend_inner_icon="mdi-magnify", + clearable=True, + hide_details=True, + density="compact", + style="max-width: 300px;", + classes="mr-4", + ) + with vuetify3.VBtn( + icon=True, + click=(self.server.controller.close_runs_table_modal,), + ): + vuetify3.VIcon("mdi-close") + with vuetify3.VCardText( + classes="pa-0", + style="height: calc(100vh - 64px); overflow: auto;", + ): + with vuetify3.VDataTable( + items=("runs_table_items",), + headers=("runs_table_headers",), + model_value=("runs_table_selected",), + update_modelValue=( + self.server.controller.update_runs_table_selected, + "[$event]", + ), + item_value="id", + show_select=True, + hover=True, + search=("runs_table_search",), + items_per_page=(100,), + click_row=( + self.server.controller.on_table_row_click, + "[$event, item]", + ), + ): + filterable_column( + "scenario_id", + "Scenario", + "runs_table_filter_scenario", + "runs_table_scenario_options", + ) + filterable_column( + "scene_id", + "Scene", + "runs_table_filter_scene", + "runs_table_scene_options", + ) + cell_with_tooltip("probe_text") + filterable_column( + "decider_name", + "Decider", + "runs_table_filter_decider", + "runs_table_decider_options", + ) + filterable_column( + "llm_backbone_name", + "LLM", + "runs_table_filter_llm", + "runs_table_llm_options", + ) + filterable_column( + "alignment_summary", + "Alignment", + "runs_table_filter_alignment", + "runs_table_alignment_options", + ) + filterable_column( + "decision_text", + "Decision", + "runs_table_filter_decision", + "runs_table_decision_options", + ) + + class ResultsComparison(html.Div): def __init__(self, **kwargs): super().__init__(classes="d-inline-flex flex-wrap ga-4 pa-1", **kwargs) @@ -840,20 +954,6 @@ def __init__(self, **kwargs): PanelSection(child=ChoiceInfo) -class ProbePanel(vuetify3.VExpansionPanel): - def __init__(self, probe, **kwargs): - super().__init__(**kwargs) - with self: - with vuetify3.VExpansionPanelTitle(): - with html.Div(classes="text-subtitle-1 text-no-wrap text-truncate"): - html.Span( - f"{{{{{probe}.probe_id}}}} - " - f"{{{{{probe}.full_state.unstructured}}}}", - ) - with vuetify3.VExpansionPanelText(): - ProbeLayout(probe) - - class RunSearchField(html.Div): """Search field for runs with dropdown results menu.""" @@ -931,11 +1031,28 @@ def __init__( with vuetify3.VBtn(icon=True, click=reload): vuetify3.VIcon("mdi-refresh") with vuetify3.VBtn( - click="utils.download('align-app-runs.json', runs_json || '[]', 'application/json')", + click=self.server.controller.open_runs_table_modal, + disabled=("Object.keys(runs).length === 0",), + prepend_icon="mdi-table", + ): + html.Span("Browse Runs") + vuetify3.VFileInput( + v_model=("import_experiment_file", None), + accept=".zip", + ref="importFileInput", + style="display: none;", 
+ ) + with vuetify3.VBtn( + click="trame.refs.importFileInput.$el.querySelector('input').click()", + prepend_icon="mdi-upload", + ): + html.Span("Load Experiments") + with vuetify3.VBtn( + click="utils.download('align-app-experiments.zip', trigger('export_runs_zip'), 'application/zip')", disabled=("Object.keys(runs).length === 0",), - prepend_icon="mdi-file-download", + prepend_icon="mdi-content-save", ): - html.Span("Export Runs") + html.Span("Save Experiments") with vuetify3.VBtn( click=self.server.controller.reset_state, prepend_icon="mdi-delete-sweep", @@ -949,6 +1066,10 @@ def __init__( "html { overflow: hidden !important; }" ".v-textarea .v-field__input { overflow-y: hidden !important; }" ".v-expansion-panel { max-width: none !important; }" + ".config-textarea textarea { white-space: pre; overflow-x: auto; }" + ".v-data-table table { table-layout: fixed; width: 100%; }" + ".v-data-table td { overflow: hidden; text-overflow: ellipsis; white-space: nowrap; }" + ".v-data-table th { vertical-align: top; }" "'" ) ) @@ -961,3 +1082,4 @@ def __init__( style="min-width: 100%; width: fit-content; padding: 16px;", ): ResultsComparison() + RunsTableModal() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..bebe038 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,40 @@ +"""Shared pytest fixtures for align-app tests.""" + +import zipfile +from pathlib import Path +from urllib.request import urlretrieve + +import pytest + +EXPERIMENTS_ZIP_URL = ( + "https://github.com/ITM-Kitware/align-app/releases/download/v1.0.0/experiments.zip" +) + +REPO_ROOT = Path(__file__).parent.parent +FIXTURES_CACHE_DIR = REPO_ROOT / "tests" / "fixtures" / ".cache" + + +@pytest.fixture(scope="session") +def experiments_fixtures_path() -> Path: + """Download and cache experiment fixtures for testing. + + Downloads experiments.zip from GitHub releases, extracts it, and caches + the result in tests/fixtures/.cache/ (gitignored). 
+ """ + FIXTURES_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + experiments_dir = FIXTURES_CACHE_DIR / "experiments" + zip_path = FIXTURES_CACHE_DIR / "experiments.zip" + + if experiments_dir.exists() and any(experiments_dir.iterdir()): + return experiments_dir + + if not zip_path.exists(): + print(f"\nDownloading experiment fixtures from {EXPERIMENTS_ZIP_URL}...") + urlretrieve(EXPERIMENTS_ZIP_URL, zip_path) + + print(f"\nExtracting experiment fixtures to {experiments_dir}...") + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(FIXTURES_CACHE_DIR) + + return experiments_dir diff --git a/tests/e2e/page_objects/align_page.py b/tests/e2e/page_objects/align_page.py index 06a01bd..eec1c86 100644 --- a/tests/e2e/page_objects/align_page.py +++ b/tests/e2e/page_objects/align_page.py @@ -160,12 +160,57 @@ def select_results_llm(self, llm_name: str) -> None: option.click() expect(listbox).not_to_be_visible() + @property + def scenario_panel(self) -> Locator: + return ( + self.page.locator(".v-expansion-panel") + .filter(has=self.page.get_by_role("button").filter(has_text="Scenario")) + .first + ) + @property def scenario_panel_title(self) -> Locator: - return self.page.get_by_role("button", name="Scenario") + return ( + self.page.get_by_role("button") + .filter(has_text="Scenario") + .filter(has=self.page.locator(".v-expansion-panel-title__overlay")) + ) + + @property + def scenario_panel_content(self) -> Locator: + return self.scenario_panel.locator(".v-expansion-panel-text") + + @property + def situation_textarea(self) -> Locator: + return self.scenario_panel_content.locator( + ".v-textarea textarea:not(.v-textarea__sizer)" + ).first def expand_scenario_panel(self) -> None: expect(self.scenario_panel_title).to_be_visible() + is_expanded = self.scenario_panel_title.get_attribute("aria-expanded") == "true" + if not is_expanded: + expand_icon = self.scenario_panel_title.locator(".mdi-chevron-down") + expand_icon.click() + expect(self.scenario_panel_title).to_have_attribute("aria-expanded", "true") + expect(self.scenario_panel_content).to_be_visible() + + def get_situation_text(self) -> str: + expect(self.situation_textarea).to_be_visible() + value = self.situation_textarea.input_value() + return value if value else "" + + def set_situation_text(self, text: str) -> None: + expect(self.situation_textarea).to_be_visible() + self.situation_textarea.fill(text) + + def blur_situation_textarea(self) -> None: + self.situation_textarea.blur() + + def get_scene_dropdown_value_from_panel(self) -> str: + dropdown = self.scenario_panel.get_by_role("combobox").filter(has_text="Scene") + value = dropdown.locator("input").input_value() + return value if value else "" @property def alignment_panel_title(self) -> Locator: @@ -247,3 +292,55 @@ def find_decider_in_open_list( if item_text and not any(excl in item_text for excl in exclude): return decider_items.nth(i), item_text return None, None + + @property + def decider_panel(self) -> Locator: + return ( + self.page.locator(".v-expansion-panel") + .filter(has=self.page.get_by_role("button").filter(has_text="Decider")) + .first + ) + + @property + def decider_panel_title(self) -> Locator: + return ( + self.page.get_by_role("button") + .filter(has_text="Decider") + .filter(has=self.page.locator(".v-expansion-panel-title__overlay")) + ) + + @property + def decider_panel_content(self) -> Locator: + return self.decider_panel.locator(".v-expansion-panel-text") + + @property + def config_textarea(self) -> Locator: + return 
self.decider_panel_content.locator( + ".config-textarea textarea:not(.v-textarea__sizer)" + ) + + def expand_decider_panel(self) -> None: + expect(self.decider_panel_title).to_be_visible() + is_expanded = self.decider_panel_title.get_attribute("aria-expanded") == "true" + if not is_expanded: + expand_icon = self.decider_panel_title.locator(".mdi-chevron-down") + expand_icon.click() + expect(self.decider_panel_title).to_have_attribute("aria-expanded", "true") + expect(self.decider_panel_content).to_be_visible() + + def get_config_yaml(self) -> str: + expect(self.config_textarea).to_be_visible() + value = self.config_textarea.input_value() + return value if value else "" + + def set_config_yaml(self, yaml_text: str) -> None: + expect(self.config_textarea).to_be_visible() + self.config_textarea.fill(yaml_text) + + def blur_config_textarea(self) -> None: + self.config_textarea.blur() + + def get_decider_dropdown_value(self) -> str: + dropdown = self.decider_panel.get_by_role("combobox").filter(has_text="Decider") + value = dropdown.locator("input").input_value() + return value if value else "" diff --git a/tests/e2e/test_config_edit.py b/tests/e2e/test_config_edit.py new file mode 100644 index 0000000..7d1b9c4 --- /dev/null +++ b/tests/e2e/test_config_edit.py @@ -0,0 +1,104 @@ +from .page_objects.align_page import AlignPage + + +def test_config_edit_creates_new_decider(page, align_app_server): + """Test that editing config YAML creates a new decider with edit suffix.""" + align_page = AlignPage(page) + align_page.goto(align_app_server) + + align_page.expand_decider_panel() + + original_decider = align_page.get_decider_dropdown_value() + original_config = align_page.get_config_yaml() + assert original_config, "Config should not be empty" + assert " - edit " not in original_decider, "Should start with non-edited decider" + + modified_config = original_config + "\ntest_key: test_value" + align_page.set_config_yaml(modified_config) + align_page.blur_config_textarea() + + page.wait_for_timeout(500) + + new_decider = align_page.get_decider_dropdown_value() + assert " - edit " in new_decider, ( + f"Expected decider to have ' - edit ' suffix after config change. " + f"Original: {original_decider}, New: {new_decider}" + ) + assert new_decider.startswith(original_decider.split(" - edit ")[0]), ( + f"New decider should be based on original. " + f"Original: {original_decider}, New: {new_decider}" + ) + + +def test_config_edit_revert_restores_original_decider(page, align_app_server): + """Test that reverting config to original restores the original decider.""" + align_page = AlignPage(page) + align_page.goto(align_app_server) + + align_page.expand_decider_panel() + + original_decider = align_page.get_decider_dropdown_value() + original_config = align_page.get_config_yaml() + assert original_config, "Config should not be empty" + + modified_config = original_config + "\ntest_key: test_value" + align_page.set_config_yaml(modified_config) + align_page.blur_config_textarea() + + page.wait_for_timeout(500) + + edited_decider = align_page.get_decider_dropdown_value() + assert " - edit " in edited_decider, ( + "Decider should have edit suffix after modification" + ) + + align_page.set_config_yaml(original_config) + align_page.blur_config_textarea() + + page.wait_for_timeout(500) + + reverted_decider = align_page.get_decider_dropdown_value() + assert reverted_decider == original_decider, ( + f"Expected decider to revert to original after restoring config. 
" + f"Original: {original_decider}, Reverted: {reverted_decider}" + ) + + +def test_config_edit_to_existing_config_reuses_decider(page, align_app_server): + """Test that editing config to match an existing config reuses that decider.""" + align_page = AlignPage(page) + align_page.goto(align_app_server) + + align_page.expand_decider_panel() + + original_config = align_page.get_config_yaml() + + modified_config_1 = original_config + "\nfirst_key: first_value" + align_page.set_config_yaml(modified_config_1) + align_page.blur_config_textarea() + page.wait_for_timeout(500) + + first_edited_decider = align_page.get_decider_dropdown_value() + assert " - edit 1" in first_edited_decider, ( + "First edit should create '- edit 1' decider" + ) + + modified_config_2 = original_config + "\nsecond_key: second_value" + align_page.set_config_yaml(modified_config_2) + align_page.blur_config_textarea() + page.wait_for_timeout(500) + + second_edited_decider = align_page.get_decider_dropdown_value() + assert " - edit 2" in second_edited_decider, ( + "Second edit should create '- edit 2' decider" + ) + + align_page.set_config_yaml(modified_config_1) + align_page.blur_config_textarea() + page.wait_for_timeout(500) + + reused_decider = align_page.get_decider_dropdown_value() + assert reused_decider == first_edited_decider, ( + f"Expected decider to reuse first edited decider. " + f"First edit: {first_edited_decider}, Reused: {reused_decider}" + ) diff --git a/tests/e2e/test_scenario_edit.py b/tests/e2e/test_scenario_edit.py new file mode 100644 index 0000000..943ea93 --- /dev/null +++ b/tests/e2e/test_scenario_edit.py @@ -0,0 +1,58 @@ +from .page_objects.align_page import AlignPage + + +def test_situation_text_edit_creates_new_scene(page, align_app_server): + """Test that editing situation text creates a new scene with edit suffix.""" + align_page = AlignPage(page) + align_page.goto(align_app_server) + + align_page.expand_scenario_panel() + + original_scene = align_page.get_scene_dropdown_value_from_panel() + original_text = align_page.get_situation_text() + assert original_text, "Situation text should not be empty" + assert " edit " not in original_scene, "Should start with non-edited scene" + + modified_text = original_text + " [test modification]" + align_page.set_situation_text(modified_text) + align_page.blur_situation_textarea() + + page.wait_for_timeout(500) + + new_scene = align_page.get_scene_dropdown_value_from_panel() + assert " edit " in new_scene, ( + f"Expected scene to have ' edit ' suffix after text change. 
" + f"Original: {original_scene}, New: {new_scene}" + ) + + +def test_situation_text_revert_restores_original_scene(page, align_app_server): + """Test that reverting situation text to original restores the original scene.""" + align_page = AlignPage(page) + align_page.goto(align_app_server) + + align_page.expand_scenario_panel() + + original_scene = align_page.get_scene_dropdown_value_from_panel() + original_text = align_page.get_situation_text() + assert original_text, "Situation text should not be empty" + + modified_text = original_text + " [test modification]" + align_page.set_situation_text(modified_text) + align_page.blur_situation_textarea() + + page.wait_for_timeout(500) + + edited_scene = align_page.get_scene_dropdown_value_from_panel() + assert " edit " in edited_scene, "Scene should have edit suffix after modification" + + align_page.set_situation_text(original_text) + align_page.blur_situation_textarea() + + page.wait_for_timeout(500) + + reverted_scene = align_page.get_scene_dropdown_value_from_panel() + assert reverted_scene == original_scene, ( + f"Expected scene to revert to original after restoring text. " + f"Original: {original_scene}, Reverted: {reverted_scene}" + ) diff --git a/tests/test_experiment_deciders.py b/tests/test_experiment_deciders.py new file mode 100644 index 0000000..1a1ff3f --- /dev/null +++ b/tests/test_experiment_deciders.py @@ -0,0 +1,145 @@ +"""Tests for experiment decider loading and instantiation.""" + +from pathlib import Path + + +def test_experiment_fixtures_download(experiments_fixtures_path: Path): + """Verify experiment fixtures are downloaded and extracted.""" + assert experiments_fixtures_path.exists() + assert experiments_fixtures_path.is_dir() + + experiment_dirs = list(experiments_fixtures_path.iterdir()) + assert len(experiment_dirs) >= 3, "Expected at least 3 experiment types" + + +def test_experiment_registry_loads_items(experiments_fixtures_path: Path): + """Verify experiment results registry loads all experiment items.""" + from align_app.adm.experiment_results_registry import ( + create_experiment_results_registry, + ) + + registry = create_experiment_results_registry(experiments_fixtures_path) + + all_items = registry.get_all_items() + assert len(all_items) > 0, "Expected experiment items to be loaded" + + +def test_unique_deciders_extracted(experiments_fixtures_path: Path): + """Verify unique deciders are extracted from experiments.""" + from align_app.adm.experiment_converters import deciders_from_experiments + from align_app.adm.experiment_results_registry import ( + create_experiment_results_registry, + ) + + registry = create_experiment_results_registry(experiments_fixtures_path) + + experiment_deciders = deciders_from_experiments(registry.get_experiments()) + assert len(experiment_deciders) == 3, ( + f"Expected 3 unique deciders, got {len(experiment_deciders)}" + ) + + expected_deciders = { + "pipeline_fewshot_comparative_regression_loo_20icl", + "pipeline_baseline", + "pipeline_baseline_greedy_w_cache", + } + assert set(experiment_deciders.keys()) == expected_deciders + + +def test_experiment_decider_config_loading(experiments_fixtures_path: Path): + """Verify experiment decider configs can be loaded.""" + from align_app.adm.experiment_config_loader import load_experiment_adm_config + from align_app.adm.experiment_converters import deciders_from_experiments + from align_app.adm.experiment_results_registry import ( + create_experiment_results_registry, + ) + + registry = 
create_experiment_results_registry(experiments_fixtures_path) + experiment_deciders = deciders_from_experiments(registry.get_experiments()) + + for name, entry in experiment_deciders.items(): + config = load_experiment_adm_config(Path(entry["experiment_path"])) + + assert config is not None, f"Config for {name} could not be loaded" + assert "instance" in config, f"Config for {name} missing 'instance'" + assert config["instance"].get("_target_"), ( + f"Config for {name} missing instance._target_" + ) + + +def test_decider_registry_includes_experiment_deciders(experiments_fixtures_path: Path): + """Verify decider registry includes experiment deciders.""" + from align_app.adm.decider_registry import create_decider_registry + from align_app.adm.experiment_converters import ( + deciders_from_experiments, + probes_from_experiment_items, + ) + from align_app.adm.experiment_results_registry import ( + create_experiment_results_registry, + ) + from align_app.adm.probe_registry import create_probe_registry + + exp_registry = create_experiment_results_registry(experiments_fixtures_path) + + probe_registry = create_probe_registry(scenarios_paths=[]) + experiment_probes = probes_from_experiment_items(exp_registry.get_all_items()) + probe_registry.add_probes(experiment_probes) + + experiment_deciders = deciders_from_experiments(exp_registry.get_experiments()) + + decider_registry = create_decider_registry( + config_paths=[], + scenario_registry=probe_registry, + experiment_deciders=experiment_deciders, + ) + + all_deciders = decider_registry.get_all_deciders() + + for exp_name in experiment_deciders: + assert exp_name in all_deciders, ( + f"Experiment decider {exp_name} not in registry" + ) + assert all_deciders[exp_name].get("experiment_config") is True + + +def test_get_decider_config_for_experiment(experiments_fixtures_path: Path): + """Verify get_decider_config works for experiment deciders.""" + from align_app.adm.decider_registry import create_decider_registry + from align_app.adm.experiment_converters import ( + deciders_from_experiments, + probes_from_experiment_items, + ) + from align_app.adm.experiment_results_registry import ( + create_experiment_results_registry, + ) + from align_app.adm.probe_registry import create_probe_registry + + exp_registry = create_experiment_results_registry(experiments_fixtures_path) + + probe_registry = create_probe_registry(scenarios_paths=[]) + experiment_probes = probes_from_experiment_items(exp_registry.get_all_items()) + probe_registry.add_probes(experiment_probes) + + experiment_deciders = deciders_from_experiments(exp_registry.get_experiments()) + + decider_registry = create_decider_registry( + config_paths=[], + scenario_registry=probe_registry, + experiment_deciders=experiment_deciders, + ) + + probes = probe_registry.get_probes() + probe_id = next(iter(probes.keys())) + exp_decider_name = "pipeline_fewshot_comparative_regression_loo_20icl" + + config = decider_registry.get_decider_config( + probe_id=probe_id, + decider=exp_decider_name, + ) + + assert config is not None + assert "instance" in config + assert ( + config["instance"]["_target_"] + == "align_system.algorithms.pipeline_adm.PipelineADM" + ) diff --git a/tests/unit/test_cache_backend.py b/tests/unit/test_cache_backend.py index 579ee89..133fb3e 100644 --- a/tests/unit/test_cache_backend.py +++ b/tests/unit/test_cache_backend.py @@ -2,7 +2,7 @@ def test_hash_run_params_generates_consistent_key(): - from align_app.app.run_models import hash_run_params + from align_app.adm.run_models import 
hash_run_params from align_app.adm.decider.types import DeciderParams scenario_input = InputData( @@ -42,7 +42,7 @@ def test_hash_run_params_generates_consistent_key(): def test_hash_run_params_different_for_changed_params(): - from align_app.app.run_models import hash_run_params + from align_app.adm.run_models import hash_run_params from align_app.adm.decider.types import DeciderParams scenario_input = InputData( diff --git a/tests/unit/test_experiment_cache.py b/tests/unit/test_experiment_cache.py new file mode 100644 index 0000000..b02520c --- /dev/null +++ b/tests/unit/test_experiment_cache.py @@ -0,0 +1,212 @@ +"""Tests for experiment cache population and retrieval.""" + +from pathlib import Path +import pytest + +from align_app.adm.run_models import Run +from align_app.adm.decider.types import DeciderParams +from align_app.app.runs_registry import RunsRegistry +from align_app.app.runs_core import Runs, populate_cache_bulk +from align_app.adm.experiment_converters import runs_from_experiment_items +from align_utils.models import AlignmentTarget, KDMAValue, get_experiment_items +from align_utils.discovery import parse_experiments_directory + + +@pytest.fixture +def single_experiment_path(experiments_fixtures_path: Path) -> Path: + """Return path to a single experiment for faster tests.""" + return ( + experiments_fixtures_path + / "pipeline_baseline_greedy_w_cache" + / "affiliation-0.0" + ) + + +def test_cache_populated_from_experiment(single_experiment_path: Path): + """Verify cache is populated from experiment items.""" + experiments = parse_experiments_directory(single_experiment_path.parent.parent) + + target_experiments = [ + exp + for exp in experiments + if "pipeline_baseline_greedy_w_cache" in str(exp.experiment_path) + and "affiliation-0.0" in str(exp.experiment_path) + ] + assert len(target_experiments) == 1 + + items = get_experiment_items(target_experiments[0]) + assert len(items) > 0 + + runs = runs_from_experiment_items(items) + assert len(runs) > 0 + assert all(run.decision is not None for run in runs) + + data = Runs.empty() + data = populate_cache_bulk(data, runs) + + assert len(data.decision_cache) == len(runs) + + +def test_cache_hit_with_matching_params(single_experiment_path: Path): + """Verify cache returns decision when params match.""" + experiments = parse_experiments_directory(single_experiment_path.parent.parent) + + target_experiments = [ + exp + for exp in experiments + if "pipeline_baseline_greedy_w_cache" in str(exp.experiment_path) + and "affiliation-0.0" in str(exp.experiment_path) + ] + + items = get_experiment_items(target_experiments[0]) + cached_runs = runs_from_experiment_items(items[:1]) + cached_run = cached_runs[0] + + class MockProbeRegistry: + def get_probe(self, probe_id): + return None + + def get_probes(self): + return {} + + class MockDeciderRegistry: + def get_system_prompt(self, **kwargs): + return "" + + runs_registry = RunsRegistry(MockProbeRegistry(), MockDeciderRegistry()) + runs_registry.populate_cache_bulk(cached_runs) + + ui_alignment_target = AlignmentTarget( + id="affiliation-0.0", kdma_values=[KDMAValue(kdma="affiliation", value=0.0)] + ) + + ui_run = Run( + id="ui-test-run", + probe_id=cached_run.probe_id, + decider_name=cached_run.decider_name, + llm_backbone_name=cached_run.llm_backbone_name, + system_prompt="", + decider_params=DeciderParams( + scenario_input=cached_run.decider_params.scenario_input, + alignment_target=ui_alignment_target, + resolved_config=cached_run.decider_params.resolved_config, + ), + ) + + 
runs_registry.add_run(ui_run) + fetched = runs_registry.get_run("ui-test-run") + + assert fetched is not None + assert fetched.decision is not None + assert ( + fetched.decision.adm_result.decision.unstructured + == cached_run.decision.adm_result.decision.unstructured + ) + + +def test_cache_miss_with_different_params(single_experiment_path: Path): + """Verify cache returns None when params don't match.""" + experiments = parse_experiments_directory(single_experiment_path.parent.parent) + + target_experiments = [ + exp + for exp in experiments + if "pipeline_baseline_greedy_w_cache" in str(exp.experiment_path) + and "affiliation-0.0" in str(exp.experiment_path) + ] + + items = get_experiment_items(target_experiments[0]) + cached_runs = runs_from_experiment_items(items[:1]) + cached_run = cached_runs[0] + + class MockProbeRegistry: + def get_probe(self, probe_id): + return None + + def get_probes(self): + return {} + + class MockDeciderRegistry: + def get_system_prompt(self, **kwargs): + return "" + + runs_registry = RunsRegistry(MockProbeRegistry(), MockDeciderRegistry()) + runs_registry.populate_cache_bulk(cached_runs) + + different_alignment_target = AlignmentTarget( + id="affiliation-0.5", kdma_values=[KDMAValue(kdma="affiliation", value=0.5)] + ) + + ui_run = Run( + id="ui-test-run-miss", + probe_id=cached_run.probe_id, + decider_name=cached_run.decider_name, + llm_backbone_name=cached_run.llm_backbone_name, + system_prompt="", + decider_params=DeciderParams( + scenario_input=cached_run.decider_params.scenario_input, + alignment_target=different_alignment_target, + resolved_config=cached_run.decider_params.resolved_config, + ), + ) + + runs_registry.add_run(ui_run) + fetched = runs_registry.get_run("ui-test-run-miss") + + assert fetched is not None + assert fetched.decision is None + + +def test_cache_preserved_after_clear_runs(single_experiment_path: Path): + """Verify cache is preserved when runs are cleared.""" + experiments = parse_experiments_directory(single_experiment_path.parent.parent) + + target_experiments = [ + exp + for exp in experiments + if "pipeline_baseline_greedy_w_cache" in str(exp.experiment_path) + and "affiliation-0.0" in str(exp.experiment_path) + ] + + items = get_experiment_items(target_experiments[0]) + cached_runs = runs_from_experiment_items(items[:1]) + cached_run = cached_runs[0] + + class MockProbeRegistry: + def get_probe(self, probe_id): + return None + + def get_probes(self): + return {} + + class MockDeciderRegistry: + def get_system_prompt(self, **kwargs): + return "" + + runs_registry = RunsRegistry(MockProbeRegistry(), MockDeciderRegistry()) + runs_registry.populate_cache_bulk(cached_runs) + + runs_registry.clear_runs() + + ui_alignment_target = AlignmentTarget( + id="affiliation-0.0", kdma_values=[KDMAValue(kdma="affiliation", value=0.0)] + ) + + ui_run = Run( + id="ui-after-clear", + probe_id=cached_run.probe_id, + decider_name=cached_run.decider_name, + llm_backbone_name=cached_run.llm_backbone_name, + system_prompt="", + decider_params=DeciderParams( + scenario_input=cached_run.decider_params.scenario_input, + alignment_target=ui_alignment_target, + resolved_config=cached_run.decider_params.resolved_config, + ), + ) + + runs_registry.add_run(ui_run) + fetched = runs_registry.get_run("ui-after-clear") + + assert fetched is not None + assert fetched.decision is not None diff --git a/tests/unit/test_runs_table_filter.py b/tests/unit/test_runs_table_filter.py new file mode 100644 index 0000000..9a36af4 --- /dev/null +++ 
b/tests/unit/test_runs_table_filter.py @@ -0,0 +1,162 @@ +from align_app.app.runs_table_filter import compute_filter_options, filter_rows + + +def test_compute_filter_options_extracts_unique_values(): + rows = [ + { + "scenario_id": "scenario_a", + "scene_id": "scene_1", + "decider_name": "decider_x", + "llm_backbone_name": "llm_1", + "alignment_summary": "aligned", + "decision_text": "choice A", + }, + { + "scenario_id": "scenario_b", + "scene_id": "scene_2", + "decider_name": "decider_y", + "llm_backbone_name": "llm_2", + "alignment_summary": "not aligned", + "decision_text": "choice B", + }, + ] + + options = compute_filter_options(rows) + + assert options["runs_table_scenario_options"] == ["scenario_a", "scenario_b"] + assert options["runs_table_scene_options"] == ["scene_1", "scene_2"] + assert options["runs_table_decider_options"] == ["decider_x", "decider_y"] + assert options["runs_table_llm_options"] == ["llm_1", "llm_2"] + assert options["runs_table_alignment_options"] == ["aligned", "not aligned"] + assert options["runs_table_decision_options"] == ["choice A", "choice B"] + + +def test_compute_filter_options_sorts_values(): + rows = [ + { + "scenario_id": "zebra", + "scene_id": "1", + "decider_name": "x", + "llm_backbone_name": "l", + "alignment_summary": "a", + "decision_text": "d", + }, + { + "scenario_id": "alpha", + "scene_id": "2", + "decider_name": "y", + "llm_backbone_name": "l", + "alignment_summary": "a", + "decision_text": "d", + }, + ] + + options = compute_filter_options(rows) + + assert options["runs_table_scenario_options"] == ["alpha", "zebra"] + + +def test_compute_filter_options_deduplicates_values(): + rows = [ + { + "scenario_id": "same", + "scene_id": "1", + "decider_name": "x", + "llm_backbone_name": "l", + "alignment_summary": "a", + "decision_text": "d", + }, + { + "scenario_id": "same", + "scene_id": "2", + "decider_name": "y", + "llm_backbone_name": "l", + "alignment_summary": "a", + "decision_text": "d", + }, + ] + + options = compute_filter_options(rows) + + assert options["runs_table_scenario_options"] == ["same"] + + +def test_filter_rows_empty_filters_returns_all(): + rows = [ + {"scenario_id": "a", "scene_id": "1"}, + {"scenario_id": "b", "scene_id": "2"}, + ] + + filters = [([], "scenario_id"), ([], "scene_id")] + result = filter_rows(rows, filters) + + assert len(result) == 2 + + +def test_filter_rows_single_filter(): + rows = [ + {"scenario_id": "a", "scene_id": "1"}, + {"scenario_id": "b", "scene_id": "2"}, + {"scenario_id": "a", "scene_id": "3"}, + ] + + filters = [(["a"], "scenario_id")] + result = filter_rows(rows, filters) + + assert len(result) == 2 + assert all(r["scenario_id"] == "a" for r in result) + + +def test_filter_rows_multiple_values_in_filter(): + rows = [ + {"scenario_id": "a", "scene_id": "1"}, + {"scenario_id": "b", "scene_id": "2"}, + {"scenario_id": "c", "scene_id": "3"}, + ] + + filters = [(["a", "b"], "scenario_id")] + result = filter_rows(rows, filters) + + assert len(result) == 2 + scenarios = {r["scenario_id"] for r in result} + assert scenarios == {"a", "b"} + + +def test_filter_rows_multiple_filters_combine_with_and(): + rows = [ + {"scenario_id": "a", "scene_id": "1"}, + {"scenario_id": "a", "scene_id": "2"}, + {"scenario_id": "b", "scene_id": "1"}, + ] + + filters = [(["a"], "scenario_id"), (["1"], "scene_id")] + result = filter_rows(rows, filters) + + assert len(result) == 1 + assert result[0]["scenario_id"] == "a" + assert result[0]["scene_id"] == "1" + + +def test_filter_rows_no_matches_returns_empty(): + rows = 
[
+        {"scenario_id": "a", "scene_id": "1"},
+        {"scenario_id": "b", "scene_id": "2"},
+    ]
+
+    filters = [(["nonexistent"], "scenario_id")]
+    result = filter_rows(rows, filters)
+
+    assert len(result) == 0
+
+
+def test_filter_rows_excludes_rows_missing_filtered_key():
+    rows = [
+        {"scenario_id": "a"},
+        {"scenario_id": "b", "scene_id": "1"},
+    ]
+
+    filters = [(["1"], "scene_id")]
+    result = filter_rows(rows, filters)
+
+    assert len(result) == 1
+    assert result[0]["scenario_id"] == "b"
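
The filter tests above pin down the expected behavior of `compute_filter_options` and `filter_rows`. For orientation, the following is a minimal sketch that satisfies those tests; it is an assumption about shape only, not the actual `align_app/app/runs_table_filter.py` implementation, and the `_FILTER_FIELDS` table is a hypothetical helper inferred from the state keys asserted in the tests.

```python
# Sketch only: one possible implementation consistent with the tests above.
from typing import Any, Dict, List, Sequence, Tuple

# (row field, options state key) pairs inferred from the test assertions.
_FILTER_FIELDS = [
    ("scenario_id", "runs_table_scenario_options"),
    ("scene_id", "runs_table_scene_options"),
    ("decider_name", "runs_table_decider_options"),
    ("llm_backbone_name", "runs_table_llm_options"),
    ("alignment_summary", "runs_table_alignment_options"),
    ("decision_text", "runs_table_decision_options"),
]


def compute_filter_options(rows: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Collect the sorted, de-duplicated values present for each filterable column."""
    return {
        option_key: sorted({row[field] for row in rows if field in row})
        for field, option_key in _FILTER_FIELDS
    }


def filter_rows(
    rows: Sequence[Dict[str, Any]],
    filters: Sequence[Tuple[Sequence[Any], str]],
) -> List[Dict[str, Any]]:
    """Keep rows matching every non-empty (selected_values, field) filter.

    Filters combine with AND semantics; an empty selection leaves that column
    unfiltered, and rows missing a filtered field are excluded.
    """
    return [
        row
        for row in rows
        if all(
            not selected or row.get(field) in selected
            for selected, field in filters
        )
    ]
```

Keeping the column-to-state-key mapping in one table lets `compute_filter_options` and the dropdowns that consume `runs_table_*_options` stay in sync from a single place, while `filter_rows` treats an empty selection as "no filter" and ANDs the active selections, which is exactly what the tests assert.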