Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions align_system/cli/run_align_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import atexit
import os
import yaml
import re

from rich.logging import RichHandler
from rich.console import Console
Expand All @@ -25,8 +26,14 @@
def main(cfg: DictConfig) -> None:
cfg = instantiate(cfg, recursive=True)

if 'adm_mappings' in cfg:
adm_mappings = cfg.adm_mappings
else:
adm_mappings = [{'scenario_pattern': '.*',
'alignment_target_pattern': '.*',
'adm': cfg.adm}]

interface = cfg.interface
adm = cfg.adm.instance

# Using the hydra generated output directory for the run
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
Expand Down Expand Up @@ -98,12 +105,6 @@ def main(cfg: DictConfig) -> None:
else:
sort_available_actions = False

# HACK: need to invoke 'load_model' for ADMs that require it,
# maybe it makes more sense to load_model in the init method for
# those ADMs
if hasattr(adm, 'load_model'):
adm.load_model()

# Capture inputs and outputs in a similar format to what's used by
# our internal evaluation framework code
inputs_outputs = []
Expand Down Expand Up @@ -149,6 +150,18 @@ def _compute_time_stats(times_s):
with open(alignment_target_path, "w") as f:
yaml.dump(alignment_target.to_dict(), f)

adm = None
for adm_mapping in adm_mappings:
if(re.match(adm_mapping['scenario_pattern'], scenario.id()) and
re.match(adm_mapping['alignment_target_pattern'], alignment_target.id)):
adm = adm_mapping['adm']
log.info(f"[bold]*Mapped scenario/target to ADM: {adm}*[/bold]",
extra={"markup": True})
break

if adm is None:
raise RuntimeError("Couldn't find an appropriate ADM in adm_mappings")

current_state = scenario.get_state()
scenario_complete = current_state.scenario_complete

Expand Down Expand Up @@ -272,11 +285,11 @@ def _compute_time_stats(times_s):
# prevent ADMs from modifying the originals (should
# consider doing the same for current_state and
# alignment_target)
action_to_take = adm.choose_action(
action_to_take = adm.instance.choose_action(
current_state,
[deepcopy(a) for a in available_actions_filtered],
alignment_target if cfg.align_to_target else None,
**cfg.adm.get('inference_kwargs', {}))
**adm.get('inference_kwargs', {}))

end_choose_action = timer()
sce_times_s.append(end_choose_action - start_choose_action)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# @package _global_
# Experiment config that loads two ADMs in a single run and routes each
# scenario / alignment target pair to one of them through `adm_mappings`.
# NOTE(review): nested keys appear at column 0 in this view; indentation
# looks to have been lost in extraction -- confirm against the real file.
defaults:
# Compose the `random` ADM config under the key `adm1` and the
# `outlines_transformers_structured_baseline` ADM config under `adm2`,
# so both can be referenced by interpolation further down.
- /adm@adm1: random
- /adm@adm2: outlines_transformers_structured_baseline
- override /interface: ta3

# Settings for the `ta3` interface selected in the defaults list.
interface:
api_endpoint: "http://127.0.0.1:8089"
session_type: eval
training_session: false
username: "ALIGN-ADM-Random-Multi-Baseline"

# When true, the run script passes the alignment target to the ADM's
# choose_action call; when false it passes None instead.
align_to_target: true
save_last_unstructured_state_per_scenario: true

# Override applied on top of the composed `adm2` config.
adm2:
instance:
precision: half

# Routing table read by the run script. Entries are tried in order and the
# first entry whose two patterns both match is used, so the catch-all entry
# must stay last. Patterns are applied with re.match to the scenario id and
# the alignment target id. If no entry matches, the run script raises
# RuntimeError.
adm_mappings:
- scenario_pattern: '^qol.*'
alignment_target_pattern: '^qol.*'
adm: ${adm2}
- scenario_pattern: '.*'
alignment_target_pattern: '.*'
adm: ${adm1}

# Hydra run directory, timestamped per run; the run script uses the Hydra
# output directory as its own output directory.
# NOTE(review): the `random_eval_live` prefix may be carried over from a
# random-only config -- confirm it is the intended name for multi-ADM runs.
hydra:
run:
dir: 'random_eval_live/${now:%Y-%m-%d__%H-%M-%S}'
23 changes: 23 additions & 0 deletions align_system/configs/multi_adm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Base config for runs that select an ADM per scenario / alignment target
# through `adm_mappings` instead of a single top-level `adm` entry.
# NOTE(review): nested keys appear at column 0 in this view; indentation
# looks to have been lost in extraction -- confirm against the real file.
name: multi_adm

defaults:
# `_self_` is listed first, so under Hydra's defaults-list ordering the
# entries below are composed after this file's own values.
- _self_
- interface: ta3
# Compose the `random` ADM config under the key `adm1` so the mapping
# below can reference it as ${adm1}.
- adm@adm1: random
- override hydra/job_logging: custom

# NOTE(review): "EXPLAIN" is not a standard Python logging level name;
# presumably defined by the custom job_logging config -- confirm.
loglevel: "EXPLAIN"

# Output toggles; how each one is consumed is not visible from here.
save_log: true
save_input_output: true
save_scoring_output: true
save_alignment_targets: false
save_timing: true
save_last_unstructured_state_per_scenario: true

# When true, the run script passes the alignment target to the ADM's
# choose_action call; when false it passes None instead.
align_to_target: true

# Default routing table: a single catch-all entry sending every scenario and
# alignment target to `adm1`. Entries are tried in order with re.match and the
# first match is used; if none matches the run script raises RuntimeError.
adm_mappings:
- scenario_pattern: '.*'
alignment_target_pattern: '.*'
adm: ${adm1}