-
Notifications
You must be signed in to change notification settings - Fork 5
Factor get_dialogs to static method in outlines_adm #172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,3 +3,5 @@ run.bash | |
| venv/ | ||
| __pycache__/ | ||
| outputs | ||
|
|
||
| .vscode/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,10 @@ | ||
| import json | ||
| import random | ||
| import itertools | ||
| import numpy as np | ||
| import torch | ||
| import yaml | ||
| import copy | ||
| from functools import partial | ||
|
|
||
| import outlines | ||
| from outlines.samplers import MultinomialSampler | ||
|
|
@@ -25,7 +25,6 @@ | |
| calculate_votes, | ||
| filter_votes_to_responses, | ||
| ) | ||
| from align_system.utils.hydrate_state import hydrate_scenario_state | ||
| from align_system.algorithms.abstracts import ActionBasedADM | ||
| from align_system.prompt_engineering.outlines_prompts import ( | ||
| baseline_system_prompt, | ||
|
|
@@ -46,8 +45,6 @@ | |
| tag_choice_json_schema, | ||
| treatment_choice_json_schema, | ||
| treatment_choice_from_list_json_schema, | ||
| detailed_unstructured_treatment_action_text, | ||
| detailed_unstructured_tagging_action_text, | ||
| high_risk_aversion_system_prompt, | ||
| low_risk_aversion_system_prompt, | ||
| high_continuing_care_system_prompt, | ||
|
|
@@ -195,7 +192,8 @@ def kdma_value_to_system_prompt(kdma, value): | |
| else: | ||
| return None | ||
|
|
||
| def _state_to_top_level_prompt(self, scenario_state, actions): | ||
| @staticmethod | ||
| def _static_state_to_top_level_prompt(action_selection_prompt_template, scenario_description, scenario_state, actions): | ||
| """ | ||
| Generate prompt dialog based on given state and actions | ||
| """ | ||
|
|
@@ -205,11 +203,23 @@ def _state_to_top_level_prompt(self, scenario_state, actions): | |
| scenario_state | ||
| ) | ||
|
|
||
| scenario_description = self.scenario_description_template(scenario_state) | ||
| prompt = self.action_selection_prompt_template(scenario_description, choices) | ||
| prompt = action_selection_prompt_template(scenario_description, choices) | ||
|
|
||
| return prompt, choices | ||
|
|
||
| def _state_to_top_level_prompt(self, scenario_state, actions): | ||
| """ | ||
| Generate prompt dialog based on given state and actions | ||
| """ | ||
| scenario_description = self.scenario_description_template(scenario_state) | ||
| return OutlinesTransformersADM._static_state_to_top_level_prompt( | ||
| self.action_selection_prompt_template, | ||
| scenario_description, | ||
| scenario_state, | ||
| actions | ||
| ) | ||
|
|
||
|
|
||
| # Function borrowed from | ||
| # https://docs.python.org/3/library/itertools.html#itertools.batched | ||
| # (since itertools.batched is only available in Python 3.12 or newer): | ||
|
|
@@ -233,21 +243,25 @@ def run_in_batches(cls, inference_function, inputs, batch_size): | |
| outputs.extend(output) | ||
| return outputs | ||
|
|
||
| def top_level_choose_action(self, | ||
| scenario_state, | ||
| available_actions, | ||
| alignment_target, | ||
| num_positive_samples=1, | ||
| num_negative_samples=0, | ||
| generator_batch_size=5, | ||
| kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml', | ||
| **kwargs): | ||
| if self.baseline and num_negative_samples > 0: | ||
| @staticmethod | ||
| def get_dialogs(scenario_state, | ||
| available_actions, | ||
| alignment_target, | ||
| num_positive_samples=1, | ||
| num_negative_samples=0, | ||
| kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml', | ||
| shuffle_choices=True, | ||
| baseline=False, | ||
| scenario_description_template=scenario_state_description_1, | ||
| action_selection_prompt_template=action_selection_prompt, | ||
| baseline_system_prompt=baseline_system_prompt, | ||
| **kwargs): | ||
| if baseline and num_negative_samples > 0: | ||
| raise RuntimeError("No notion of negative samples for baseline run") | ||
| if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0: | ||
| if baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0: | ||
| raise RuntimeError("No notion of incontext examples for baseline run") | ||
|
|
||
| scenario_description = self.scenario_description_template(scenario_state) | ||
| scenario_description = scenario_description_template(scenario_state) | ||
| # Important that the choices stay in the same order as the | ||
| # available actions as we'll use the selected index later to | ||
| # map to the corresponding action | ||
|
|
@@ -260,12 +274,11 @@ def top_level_choose_action(self, | |
| positive_icl_examples = [] | ||
| negative_icl_examples = [] | ||
| incontext_settings=kwargs.get("incontext", {}) | ||
| if not self.baseline and alignment_target is not None: | ||
| kdma_values = alignment_target.kdma_values | ||
|
|
||
| if not baseline and alignment_target is not None: | ||
| kdma_values = alignment_target.kdma_values | ||
| if len(kdma_values) != 1: | ||
| raise RuntimeError("This ADM assumes a single KDMA target, aborting!") | ||
|
|
||
| kdma_value = kdma_values[0] | ||
| if isinstance(kdma_value, KDMAValue): | ||
| kdma_value = kdma_value.to_dict() | ||
|
|
@@ -279,8 +292,8 @@ def top_level_choose_action(self, | |
| kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader) | ||
| name = kdma_descriptions[kdma]['name'] | ||
|
|
||
| positive_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, value) | ||
| negative_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, negative_value) | ||
| positive_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, value) | ||
| negative_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, negative_value) | ||
|
|
||
| if positive_system_prompt is None: | ||
| raise RuntimeError("Couldn't find system prompt for kdma: {}, and " | ||
|
|
@@ -290,8 +303,7 @@ def top_level_choose_action(self, | |
| "value: {}.".format(kdma, negative_value)) | ||
|
|
||
| if "incontext" in kwargs and "number" in incontext_settings and incontext_settings["number"] > 0: | ||
| scenario_to_match = self.scenario_description_template(scenario_state) | ||
| prompt_to_match, _ = self._state_to_top_level_prompt(scenario_state, available_actions) | ||
| prompt_to_match, _ = OutlinesTransformersADM._state_to_top_level_prompt(action_selection_prompt_template, scenario_state, available_actions) | ||
|
|
||
| # Create positive ICL example generators | ||
| positive_target = {'kdma': kdma, 'name': name, 'value': value} | ||
|
|
@@ -300,7 +312,7 @@ def top_level_choose_action(self, | |
| # Get subset of relevant of examples | ||
| positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples( | ||
| sys_kdma_name=kdma, | ||
| scenario_description_to_match=scenario_to_match, | ||
| scenario_description_to_match=scenario_description, | ||
| prompt_to_match=prompt_to_match, | ||
| state_comparison=scenario_state | ||
| ) | ||
|
|
@@ -320,7 +332,7 @@ def top_level_choose_action(self, | |
| # Get subset of relevant of examples | ||
| negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples( | ||
| sys_kdma_name=kdma, | ||
| scenario_description_to_match=scenario_to_match, | ||
| scenario_description_to_match=scenario_description, | ||
| prompt_to_match=prompt_to_match, | ||
| state_comparison=scenario_state | ||
| ) | ||
|
|
@@ -330,17 +342,17 @@ def top_level_choose_action(self, | |
| {"role": "assistant", "content": f'{icl_sample["response"]}'} | ||
| ]) | ||
| else: | ||
| positive_system_prompt = self.baseline_system_prompt() | ||
| positive_system_prompt = baseline_system_prompt() | ||
| if num_negative_samples > 0: | ||
| raise RuntimeError("No notion of negative samples for baseline run") | ||
| if "incontext" in kwargs and kwargs["incontext"]["number"] > 0: | ||
| raise RuntimeError("No notion of incontext examples for baseline run") | ||
| negative_system_prompt = None # Not used in baseline | ||
|
|
||
| positive_dialogs = [] | ||
| for _ in range(num_positive_samples): | ||
| shuffled_choices = random.sample(choices, len(choices)) | ||
|
|
||
| prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices) | ||
| shuf = random.sample(choices, len(choices)) if shuffle_choices else choices | ||
| prompt = action_selection_prompt(scenario_description, shuf) | ||
| dialog = [{'role': 'system', 'content': positive_system_prompt}] | ||
| dialog.extend(positive_icl_examples) | ||
| dialog.append({'role': 'user', 'content': prompt}) | ||
|
|
@@ -349,24 +361,73 @@ def top_level_choose_action(self, | |
|
|
||
| negative_dialogs = [] | ||
| for _ in range(num_negative_samples): | ||
| shuffled_choices = random.sample(choices, len(choices)) | ||
|
|
||
| prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices) | ||
| shuf = random.sample(choices, len(choices)) if shuffle_choices else choices | ||
| prompt = action_selection_prompt(scenario_description, shuf) | ||
| dialog = [{'role': 'system', 'content': negative_system_prompt}] | ||
| dialog.extend(negative_icl_examples) | ||
| dialog.append({'role': 'user', 'content': prompt}) | ||
|
|
||
| negative_dialogs.append(dialog) | ||
|
|
||
| return {"scenario_description": scenario_description, | ||
| "choices": choices, | ||
| "positive_system_prompt": positive_system_prompt, | ||
| "negative_system_prompt": negative_system_prompt, | ||
| "positive_dialogs": positive_dialogs, | ||
| "negative_dialogs": negative_dialogs} | ||
|
|
||
| def top_level_choose_action(self, | ||
| scenario_state, | ||
| available_actions, | ||
| alignment_target, | ||
| num_positive_samples=1, | ||
| num_negative_samples=0, | ||
| generator_batch_size=5, | ||
| kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml', | ||
| reasoning_max_length=512, | ||
| generator_seed=-1, | ||
| max_generator_tokens=-1, | ||
| shuffle_choices=True, | ||
| **kwargs): | ||
| if self.baseline and num_negative_samples > 0: | ||
| raise RuntimeError("No notion of negative samples for baseline run") | ||
| if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0: | ||
| raise RuntimeError("No notion of incontext examples for baseline run") | ||
|
|
||
| dialogs_data = OutlinesTransformersADM.get_dialogs( | ||
| scenario_state, | ||
| available_actions, | ||
| alignment_target, | ||
| num_positive_samples, | ||
| num_negative_samples, | ||
| kdma_descriptions_map, | ||
| shuffle_choices, | ||
| baseline=self.baseline, | ||
| scenario_description_template=self.scenario_description_template, | ||
| action_selection_prompt_template=self.action_selection_prompt_template, | ||
| baseline_system_prompt=self.baseline_system_prompt, | ||
| ) | ||
| choices = dialogs_data["choices"] | ||
| positive_dialogs = dialogs_data["positive_dialogs"] | ||
| negative_dialogs = dialogs_data["negative_dialogs"] | ||
|
|
||
| # Need to set the whitespace_pattern to prevent the state | ||
| # machine from looping indefinitely in some cases, see: | ||
| # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934 | ||
| generator = outlines.generate.json( | ||
| self.model, | ||
| action_choice_json_schema(json.dumps(choices)), | ||
| action_choice_json_schema(json.dumps(choices), reasoning_max_length), | ||
| sampler=self.sampler, | ||
| whitespace_pattern=r"[ ]?") | ||
|
|
||
| if max_generator_tokens >= 0: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an interesting approach, normally I would set the default in the kwargs to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The -1 is a gross hack, and I regret it already =)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wouldn't say it's a hack, just curious about the tradeoffs; I'm content to leave it as is. |
||
| generator = partial(generator, max_tokens=max_generator_tokens) | ||
|
|
||
| if generator_seed >= 0: | ||
| torch.manual_seed(generator_seed) | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.manual_seed(generator_seed) | ||
|
|
||
|
|
||
| dialog_texts = [self.dialog_to_prompt(d) for d in | ||
| itertools.chain(positive_dialogs, negative_dialogs)] | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious why you prefer
OutlinesTransformersADMoverself.__class__here (is the former just less obtuse but does the same thing? Or is there some case where one approach will break). Not asking for a change here just want to pick your brainThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made that function static, so no self handy. Mabye there is some snakey black magic to get a
__class__in static func?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I missed that bit (not sure about getting
__class__in a static method) let's leave as is