Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ run.bash
venv/
__pycache__/
outputs

.vscode/
135 changes: 98 additions & 37 deletions align_system/algorithms/outlines_adm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import random
import itertools
import numpy as np
import torch
import yaml
import copy
from functools import partial

import outlines
from outlines.samplers import MultinomialSampler
Expand All @@ -25,7 +25,6 @@
calculate_votes,
filter_votes_to_responses,
)
from align_system.utils.hydrate_state import hydrate_scenario_state
from align_system.algorithms.abstracts import ActionBasedADM
from align_system.prompt_engineering.outlines_prompts import (
baseline_system_prompt,
Expand All @@ -46,8 +45,6 @@
tag_choice_json_schema,
treatment_choice_json_schema,
treatment_choice_from_list_json_schema,
detailed_unstructured_treatment_action_text,
detailed_unstructured_tagging_action_text,
high_risk_aversion_system_prompt,
low_risk_aversion_system_prompt,
high_continuing_care_system_prompt,
Expand Down Expand Up @@ -195,7 +192,8 @@ def kdma_value_to_system_prompt(kdma, value):
else:
return None

def _state_to_top_level_prompt(self, scenario_state, actions):
@staticmethod
def _static_state_to_top_level_prompt(action_selection_prompt_template, scenario_description, scenario_state, actions):
"""
Generate prompt dialog based on given state and actions
"""
Expand All @@ -205,11 +203,23 @@ def _state_to_top_level_prompt(self, scenario_state, actions):
scenario_state
)

scenario_description = self.scenario_description_template(scenario_state)
prompt = self.action_selection_prompt_template(scenario_description, choices)
prompt = action_selection_prompt_template(scenario_description, choices)

return prompt, choices

def _state_to_top_level_prompt(self, scenario_state, actions):
    """
    Build the top-level action selection prompt for this instance.

    Renders the scenario description with the instance's
    `scenario_description_template`, then hands it, together with the
    instance's `action_selection_prompt_template`, the scenario state
    and the candidate actions, to the static
    `_static_state_to_top_level_prompt` helper.  The helper's result
    is returned unchanged.
    """
    description = self.scenario_description_template(scenario_state)
    # Delegate to the static variant so the instance and static code
    # paths share one prompt-building implementation.
    build_prompt = OutlinesTransformersADM._static_state_to_top_level_prompt
    return build_prompt(
        self.action_selection_prompt_template,
        description,
        scenario_state,
        actions,
    )


# Function borrowed from
# https://docs.python.org/3/library/itertools.html#itertools.batched
# (since itertools.batched is only available in Python 3.12 or newer):
Expand All @@ -233,21 +243,25 @@ def run_in_batches(cls, inference_function, inputs, batch_size):
outputs.extend(output)
return outputs

def top_level_choose_action(self,
scenario_state,
available_actions,
alignment_target,
num_positive_samples=1,
num_negative_samples=0,
generator_batch_size=5,
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
**kwargs):
if self.baseline and num_negative_samples > 0:
@staticmethod
def get_dialogs(scenario_state,
available_actions,
alignment_target,
num_positive_samples=1,
num_negative_samples=0,
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
shuffle_choices=True,
baseline=False,
scenario_description_template=scenario_state_description_1,
action_selection_prompt_template=action_selection_prompt,
baseline_system_prompt=baseline_system_prompt,
**kwargs):
if baseline and num_negative_samples > 0:
raise RuntimeError("No notion of negative samples for baseline run")
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
if baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
raise RuntimeError("No notion of incontext examples for baseline run")

scenario_description = self.scenario_description_template(scenario_state)
scenario_description = scenario_description_template(scenario_state)
# Important that the choices stay in the same order as the
# available actions as we'll use the selected index later to
# map to the corresponding action
Expand All @@ -260,12 +274,11 @@ def top_level_choose_action(self,
positive_icl_examples = []
negative_icl_examples = []
incontext_settings=kwargs.get("incontext", {})
if not self.baseline and alignment_target is not None:
kdma_values = alignment_target.kdma_values

if not baseline and alignment_target is not None:
kdma_values = alignment_target.kdma_values
if len(kdma_values) != 1:
raise RuntimeError("This ADM assumes a single KDMA target, aborting!")

kdma_value = kdma_values[0]
if isinstance(kdma_value, KDMAValue):
kdma_value = kdma_value.to_dict()
Expand All @@ -279,8 +292,8 @@ def top_level_choose_action(self,
kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader)
name = kdma_descriptions[kdma]['name']

positive_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, value)
negative_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, negative_value)
positive_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, value)
negative_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, negative_value)
Comment on lines -282 to +296
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious why you prefer `OutlinesTransformersADM` over `self.__class__` here. Is the former just less obtuse while doing the same thing, or is there some case where one approach would break? Not asking for a change here; I just want to pick your brain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made that function static, so there is no `self` handy. Maybe there is some snaky black magic to get `__class__` in a static function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I missed that bit (I'm not sure about getting `__class__` in a static method). Let's leave it as is.


if positive_system_prompt is None:
raise RuntimeError("Couldn't find system prompt for kdma: {}, and "
Expand All @@ -290,8 +303,7 @@ def top_level_choose_action(self,
"value: {}.".format(kdma, negative_value))

if "incontext" in kwargs and "number" in incontext_settings and incontext_settings["number"] > 0:
scenario_to_match = self.scenario_description_template(scenario_state)
prompt_to_match, _ = self._state_to_top_level_prompt(scenario_state, available_actions)
prompt_to_match, _ = OutlinesTransformersADM._state_to_top_level_prompt(action_selection_prompt_template, scenario_state, available_actions)

# Create positive ICL example generators
positive_target = {'kdma': kdma, 'name': name, 'value': value}
Expand All @@ -300,7 +312,7 @@ def top_level_choose_action(self,
# Get subset of relevant examples
positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples(
sys_kdma_name=kdma,
scenario_description_to_match=scenario_to_match,
scenario_description_to_match=scenario_description,
prompt_to_match=prompt_to_match,
state_comparison=scenario_state
)
Expand All @@ -320,7 +332,7 @@ def top_level_choose_action(self,
# Get subset of relevant examples
negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples(
sys_kdma_name=kdma,
scenario_description_to_match=scenario_to_match,
scenario_description_to_match=scenario_description,
prompt_to_match=prompt_to_match,
state_comparison=scenario_state
)
Expand All @@ -330,17 +342,17 @@ def top_level_choose_action(self,
{"role": "assistant", "content": f'{icl_sample["response"]}'}
])
else:
positive_system_prompt = self.baseline_system_prompt()
positive_system_prompt = baseline_system_prompt()
if num_negative_samples > 0:
raise RuntimeError("No notion of negative samples for baseline run")
if "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
raise RuntimeError("No notion of incontext examples for baseline run")
negative_system_prompt = None # Not used in baseline

positive_dialogs = []
for _ in range(num_positive_samples):
shuffled_choices = random.sample(choices, len(choices))

prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices)
shuf = random.sample(choices, len(choices)) if shuffle_choices else choices
prompt = action_selection_prompt(scenario_description, shuf)
dialog = [{'role': 'system', 'content': positive_system_prompt}]
dialog.extend(positive_icl_examples)
dialog.append({'role': 'user', 'content': prompt})
Expand All @@ -349,24 +361,73 @@ def top_level_choose_action(self,

negative_dialogs = []
for _ in range(num_negative_samples):
shuffled_choices = random.sample(choices, len(choices))

prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices)
shuf = random.sample(choices, len(choices)) if shuffle_choices else choices
prompt = action_selection_prompt(scenario_description, shuf)
dialog = [{'role': 'system', 'content': negative_system_prompt}]
dialog.extend(negative_icl_examples)
dialog.append({'role': 'user', 'content': prompt})

negative_dialogs.append(dialog)

return {"scenario_description": scenario_description,
"choices": choices,
"positive_system_prompt": positive_system_prompt,
"negative_system_prompt": negative_system_prompt,
"positive_dialogs": positive_dialogs,
"negative_dialogs": negative_dialogs}

def top_level_choose_action(self,
scenario_state,
available_actions,
alignment_target,
num_positive_samples=1,
num_negative_samples=0,
generator_batch_size=5,
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
reasoning_max_length=512,
generator_seed=-1,
max_generator_tokens=-1,
shuffle_choices=True,
**kwargs):
if self.baseline and num_negative_samples > 0:
raise RuntimeError("No notion of negative samples for baseline run")
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
raise RuntimeError("No notion of incontext examples for baseline run")

dialogs_data = OutlinesTransformersADM.get_dialogs(
scenario_state,
available_actions,
alignment_target,
num_positive_samples,
num_negative_samples,
kdma_descriptions_map,
shuffle_choices,
baseline=self.baseline,
scenario_description_template=self.scenario_description_template,
action_selection_prompt_template=self.action_selection_prompt_template,
baseline_system_prompt=self.baseline_system_prompt,
)
choices = dialogs_data["choices"]
positive_dialogs = dialogs_data["positive_dialogs"]
negative_dialogs = dialogs_data["negative_dialogs"]

# Need to set the whitespace_pattern to prevent the state
# machine from looping indefinitely in some cases, see:
# https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
generator = outlines.generate.json(
self.model,
action_choice_json_schema(json.dumps(choices)),
action_choice_json_schema(json.dumps(choices), reasoning_max_length),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

if max_generator_tokens >= 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting approach. Normally I would set the default in the kwargs to None and then, if it's not None, pass the value along as-is; but I guess in this case (setting the default to -1) you get to do some semantic validation on the value at the same time. Is that the intention, or is this just a tomato vs. tomato situation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The -1 is a gross hack, and I regret it already =)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't say it's a hack, just curious about the tradeoffs; I'm content to leave it as is.

generator = partial(generator, max_tokens=max_generator_tokens)

if generator_seed >= 0:
torch.manual_seed(generator_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(generator_seed)


dialog_texts = [self.dialog_to_prompt(d) for d in
itertools.chain(positive_dialogs, negative_dialogs)]

Expand Down
9 changes: 6 additions & 3 deletions align_system/prompt_engineering/outlines_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,13 +510,16 @@ def followup_clarify_aid(character, available_aids):


@outlines.prompt
def action_choice_json_schema(choices_json_str):
def action_choice_json_schema(choices_json_str, reasoning_max_length=512):
'''
{"$defs": {"ActionChoice": {"enum": {{ choices_json_str }},
"title": "ActionChoice",
"type": "string"}},
"properties": {"detailed_reasoning": {"title": "Detailed Reasoning",
"type": "string", "minLength": 1, "maxLength": 512},
"properties": {"detailed_reasoning": {
"title": "Detailed Reasoning",
"type": "string",
"minLength": 1{% if reasoning_max_length > 0 %}, "maxLength": {{ reasoning_max_length }}{% endif %}
},
"action_choice": {"$ref": "#/$defs/ActionChoice"}},
"required": ["detailed_reasoning", "action_choice"],
"title": "ActionSelection",
Expand Down