148 changes: 110 additions & 38 deletions align_system/algorithms/outlines_adm.py
@@ -25,7 +25,6 @@
calculate_votes,
filter_votes_to_responses,
)
from align_system.utils.hydrate_state import hydrate_scenario_state
from align_system.algorithms.abstracts import ActionBasedADM
from align_system.prompt_engineering.outlines_prompts import (
baseline_system_prompt,
@@ -41,29 +40,33 @@
followup_clarify_treatment_from_list,
followup_clarify_tag,
action_choice_json_schema,
action_choice_json_schema_untrimmed,
aid_choice_json_schema,
character_choice_json_schema,
tag_choice_json_schema,
treatment_choice_json_schema,
treatment_choice_from_list_json_schema,
detailed_unstructured_treatment_action_text,
detailed_unstructured_tagging_action_text

)

log = logging.getLogger(__name__)
JSON_HIGHLIGHTER = JSONHighlighter()


class OutlinesTransformersADM(ActionBasedADM):
def __init__(self,
model_name,
device='auto',
baseline=False,
sampler=MultinomialSampler(),
**kwargs):
def __init__(
self,
model_name,
device='auto',
baseline=False,
mode='eval',
sampler=MultinomialSampler(),
**kwargs
):
self.baseline = baseline

model_kwargs = kwargs.get('model_kwargs', {})
self.mode = mode
if 'precision' in kwargs:
if kwargs['precision'] == 'half':
torch_dtype = torch.float16
Expand All @@ -87,6 +90,8 @@ def __init__(self,
# the sampler itself (which defaults to 1); setting the number
# of samples in the sampler may result in unexpected behavior
self.sampler = sampler
# Edited prompt from the Demo interface
self._system_ui_prompt = None

def dialog_to_prompt(self, dialog):
tokenizer = self.model.tokenizer.tokenizer
@@ -152,25 +157,35 @@ def batched(cls, iterable, n):
yield batch
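The body of batched is collapsed in this view; a minimal sketch of the recipe such a helper usually follows (an assumption based on the standard itertools "batched" pattern, not the repository's exact code):

import itertools

def batched(iterable, n):
    # Yield successive n-sized tuples; the final batch may be shorter.
    it = iter(iterable)
    while batch := tuple(itertools.islice(it, n)):
        yield batch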

@classmethod
def run_in_batches(cls, inference_function, inputs, batch_size):
def run_in_batches(cls, inference_function, inputs, batch_size, max_tokens, seed):
'''Batch inference to avoid out-of-memory errors'''
outputs = []
for batch in cls.batched(inputs, batch_size):
output = inference_function(list(batch))
output = inference_function(
list(batch),
max_tokens=max_tokens,
# torch.cuda.manual_seed() returns None, so pass a seeded
# torch.Generator for outlines to actually use the seed
rng=torch.Generator(device="cuda").manual_seed(seed)
)
if not isinstance(output, list):
output = [output]
outputs.extend(output)
return outputs
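As a usage sketch (assuming an instantiated ADM adm and prepared dialog_texts; the choices, batch size, token cap, and seed below are illustrative), the call mirrors the one in top_level_choose_action further down:

import json

import outlines

choices = ["Treat Patient A", "Treat Patient B"]  # illustrative
generator = outlines.generate.json(
    adm.model,
    action_choice_json_schema(json.dumps(choices)),
    sampler=adm.sampler,
    whitespace_pattern=r"[ ]?")
responses = OutlinesTransformersADM.run_in_batches(
    inference_function=generator,
    inputs=dialog_texts,
    batch_size=5,
    max_tokens=512,
    seed=0)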

def top_level_choose_action(self,
scenario_state,
available_actions,
alignment_target,
num_positive_samples=1,
num_negative_samples=0,
generator_batch_size=5,
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
**kwargs):
@property
def system_ui_prompt(self) -> str:
return self._system_ui_prompt

@system_ui_prompt.setter
def system_ui_prompt(self, edited_system_prompt: str):
self._system_ui_prompt = edited_system_prompt
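A hypothetical demo-mode wiring (the model name and prompt text are placeholders): the UI stores its edited prompt on the ADM, and get_dialog_texts substitutes it for the positive system prompt only when mode == "demo":

adm = OutlinesTransformersADM(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder
    mode="demo")
adm.system_ui_prompt = "You are an assistant making triage decisions..."
# Assigning None restores the default prompt; eval mode ignores this value.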

def get_dialog_texts(self, scenario_state,
available_actions,
alignment_target,
num_positive_samples=1,
num_negative_samples=0,
kdma_descriptions_map="align_system/prompt_engineering/kdma_descriptions.yml",
**kwargs):
if self.baseline and num_negative_samples > 0:
raise RuntimeError("No notion of negative samples for baseline run")
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
@@ -185,6 +200,8 @@ def top_level_choose_action(self,
available_actions,
scenario_state
)
# Sort the choices
choices = sorted(choices)

positive_icl_examples = []
negative_icl_examples = []
@@ -224,8 +241,11 @@

# Create positive ICL example generators
positive_target = {'kdma': kdma, 'name': name, 'value': value}
positive_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
[positive_target])
positive_icl_example_generator = (
incontext_utils.BaselineIncontextExampleGenerator(
incontext_settings, [positive_target]
)
)
# Get the subset of relevant examples
positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples(
sys_kdma_name=kdma,
@@ -244,8 +264,11 @@
if num_negative_samples > 0:
# Create negative ICL example generators
negative_target = {'kdma': kdma, 'name': name, 'value': negative_value}
negative_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
[negative_target])
negative_icl_example_generator = (
incontext_utils.BaselineIncontextExampleGenerator(
incontext_settings, [negative_target]
)
)
# Get the subset of relevant examples
negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples(
sys_kdma_name=kdma,
@@ -265,10 +288,13 @@
if "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
raise RuntimeError("No notion of incontext examples for baseline run")

shuffled_choices = choices
positive_dialogs = []
for _ in range(num_positive_samples):
shuffled_choices = random.sample(choices, len(choices))

if kwargs["demo_kwargs"]["shuffle_choices"]:
shuffled_choices = random.sample(choices, len(choices))
if self.system_ui_prompt is not None and self.mode == "demo":
positive_system_prompt = self.system_ui_prompt
prompt = action_selection_prompt(scenario_description, shuffled_choices)
dialog = [{'role': 'system', 'content': positive_system_prompt}]
dialog.extend(positive_icl_examples)
@@ -278,8 +304,8 @@

negative_dialogs = []
for _ in range(num_negative_samples):
shuffled_choices = random.sample(choices, len(choices))

if kwargs["demo_kwargs"]["shuffle_choices"]:
shuffled_choices = random.sample(choices, len(choices))
prompt = action_selection_prompt(scenario_description, shuffled_choices)
dialog = [{'role': 'system', 'content': negative_system_prompt}]
dialog.extend(negative_icl_examples)
@@ -290,20 +316,66 @@
# Need to set the whitespace_pattern to prevent the state
# machine from looping indefinitely in some cases, see:
# https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
generator = outlines.generate.json(
self.model,
action_choice_json_schema(json.dumps(choices)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

dialog_texts = [self.dialog_to_prompt(d) for d in
itertools.chain(positive_dialogs, negative_dialogs)]
dialog_texts = [
self.dialog_to_prompt(d)
for d in itertools.chain(positive_dialogs, negative_dialogs)
]

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True})
log.info(dialog_texts[0])

responses = self.run_in_batches(generator, dialog_texts, generator_batch_size)
return dialog_texts, positive_dialogs

def top_level_choose_action(self,
scenario_state,
available_actions,
alignment_target,
num_positive_samples=1,
num_negative_samples=0,
generator_batch_size=5,
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
**kwargs):
if self.mode == "demo" and 'demo_kwargs' not in kwargs:
raise ValueError("Demo mode requires 'demo_kwargs' in kwargs")

demo_kwargs = kwargs["demo_kwargs"]
choices = adm_utils.format_choices(
[a.unstructured for a in available_actions],
available_actions,
scenario_state,
)

dialog_texts, positive_dialogs = self.get_dialog_texts(
scenario_state,
available_actions,
alignment_target,
num_positive_samples=num_positive_samples,
num_negative_samples=num_negative_samples,
generator_batch_size=generator_batch_size,
kdma_descriptions_map=kdma_descriptions_map,
**kwargs
)
if self.mode == "eval":
generator = outlines.generate.json(
self.model,
action_choice_json_schema(json.dumps(choices)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")
else:
generator = outlines.generate.json(
self.model,
action_choice_json_schema_untrimmed(json.dumps(choices)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?",
)
responses = self.run_in_batches(
inference_function=generator,
inputs=dialog_texts,
batch_size=generator_batch_size,
max_tokens=int(demo_kwargs["max_generator_tokens"]),
seed=int(demo_kwargs["generator_seed"]),
)
positive_responses_choices = [
r['action_choice'] for r in responses[0:num_positive_samples]
]
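For reference, a demo_kwargs mapping consistent with the keys this diff consumes (shuffle_choices, max_generator_tokens, generator_seed); the values shown are illustrative, not repository defaults:

demo_kwargs = {
    "shuffle_choices": True,       # reshuffle choices per sampled dialog
    "max_generator_tokens": 768,   # cap on tokens per generated response
    "generator_seed": 2,           # seeds the generator in run_in_batches
}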
15 changes: 15 additions & 0 deletions align_system/prompt_engineering/outlines_prompts.py
@@ -347,6 +347,21 @@ def action_choice_json_schema(choices_json_str):
'''


@outlines.prompt
def action_choice_json_schema_untrimmed(choices_json_str):
'''
{"$defs": {"ActionChoice": {"enum": {{ choices_json_str }},
"title": "ActionChoice",
"type": "string"}},
"properties": {"detailed_reasoning": {"title": "Detailed Reasoning",
"type": "string", "minLength": 1},
"action_choice": {"$ref": "#/$defs/ActionChoice"}},
"required": ["detailed_reasoning", "action_choice"],
"title": "ActionSelection",
"type": "object"}
'''
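A quick sanity check (the choices are illustrative) that the rendered template parses as a JSON Schema object with the added minLength constraint:

import json

choices = ["Apply tourniquet", "Check airway"]  # illustrative
schema = json.loads(
    action_choice_json_schema_untrimmed(json.dumps(choices)))
assert schema["required"] == ["detailed_reasoning", "action_choice"]
assert schema["properties"]["detailed_reasoning"]["minLength"] == 1
assert schema["$defs"]["ActionChoice"]["enum"] == choices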


@outlines.prompt
def character_choice_json_schema(choices_json_str):
'''