diff --git a/align_system/cli/run_align_system.py b/align_system/cli/run_align_system.py index 7bb7d08c..c24876d8 100644 --- a/align_system/cli/run_align_system.py +++ b/align_system/cli/run_align_system.py @@ -3,6 +3,7 @@ import atexit import os import yaml +import re from rich.logging import RichHandler from rich.console import Console @@ -25,8 +26,14 @@ def main(cfg: DictConfig) -> None: cfg = instantiate(cfg, recursive=True) + if 'adm_mappings' in cfg: + adm_mappings = cfg.adm_mappings + else: + adm_mappings = [{'scenario_pattern': '.*', + 'alignment_target_pattern': '.*', + 'adm': cfg.adm}] + interface = cfg.interface - adm = cfg.adm.instance # Using the hydra generated output directory for the run output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir @@ -98,12 +105,6 @@ def main(cfg: DictConfig) -> None: else: sort_available_actions = False - # HACK: need to invoke 'load_model' for ADMs that require it, - # maybe it makes more sense to load_model in the init method for - # those ADMs - if hasattr(adm, 'load_model'): - adm.load_model() - # Capture inputs and outputs in a similar format to what's used by # our internal evaluation framework code inputs_outputs = [] @@ -149,6 +150,18 @@ def _compute_time_stats(times_s): with open(alignment_target_path, "w") as f: yaml.dump(alignment_target.to_dict(), f) + adm = None + for adm_mapping in adm_mappings: + if(re.match(adm_mapping['scenario_pattern'], scenario.id()) and + re.match(adm_mapping['alignment_target_pattern'], alignment_target.id)): + adm = adm_mapping['adm'] + log.info(f"[bold]*Mapped scenario/target to ADM: {adm}*[/bold]", + extra={"markup": True}) + break + + if adm is None: + raise RuntimeError("Couldn't find an appropriate ADM in adm_mappings") + current_state = scenario.get_state() scenario_complete = current_state.scenario_complete @@ -272,11 +285,11 @@ def _compute_time_stats(times_s): # prevent ADMs from modifying the originals (should # considering doing the same for current_state and # alignment_target) - action_to_take = adm.choose_action( + action_to_take = adm.instance.choose_action( current_state, [deepcopy(a) for a in available_actions_filtered], alignment_target if cfg.align_to_target else None, - **cfg.adm.get('inference_kwargs', {})) + **adm.get('inference_kwargs', {})) end_choose_action = timer() sce_times_s.append(end_choose_action - start_choose_action) diff --git a/align_system/configs/experiment/dry_run_evaluation/multi_adm_test_eval.yaml b/align_system/configs/experiment/dry_run_evaluation/multi_adm_test_eval.yaml new file mode 100644 index 00000000..fe33d31c --- /dev/null +++ b/align_system/configs/experiment/dry_run_evaluation/multi_adm_test_eval.yaml @@ -0,0 +1,30 @@ +# @package _global_ +defaults: + - /adm@adm1: random + - /adm@adm2: outlines_transformers_structured_baseline + - override /interface: ta3 + +interface: + api_endpoint: "http://127.0.0.1:8089" + session_type: eval + training_session: false + username: "ALIGN-ADM-Random-Multi-Baseline" + +align_to_target: true +save_last_unstructured_state_per_scenario: true + +adm2: + instance: + precision: half + +adm_mappings: + - scenario_pattern: '^qol.*' + alignment_target_pattern: '^qol.*' + adm: ${adm2} + - scenario_pattern: '.*' + alignment_target_pattern: '.*' + adm: ${adm1} + +hydra: + run: + dir: 'random_eval_live/${now:%Y-%m-%d__%H-%M-%S}' diff --git a/align_system/configs/multi_adm.yaml b/align_system/configs/multi_adm.yaml new file mode 100644 index 00000000..004b342a --- /dev/null +++ b/align_system/configs/multi_adm.yaml @@ -0,0 +1,23 @@ +name: multi_adm + +defaults: + - _self_ + - interface: ta3 + - adm@adm1: random + - override hydra/job_logging: custom + +loglevel: "EXPLAIN" + +save_log: true +save_input_output: true +save_scoring_output: true +save_alignment_targets: false +save_timing: true +save_last_unstructured_state_per_scenario: true + +align_to_target: true + +adm_mappings: + - scenario_pattern: '.*' + alignment_target_pattern: '.*' + adm: ${adm1}