diff --git a/align_system/algorithms/outlines_adm.py b/align_system/algorithms/outlines_adm.py
index 9dda92b9..9694d7b8 100644
--- a/align_system/algorithms/outlines_adm.py
+++ b/align_system/algorithms/outlines_adm.py
@@ -25,7 +25,6 @@
     calculate_votes,
     filter_votes_to_responses,
 )
-from align_system.utils.hydrate_state import hydrate_scenario_state
 from align_system.algorithms.abstracts import ActionBasedADM
 from align_system.prompt_engineering.outlines_prompts import (
     baseline_system_prompt,
@@ -41,13 +40,13 @@
     followup_clarify_treatment_from_list,
     followup_clarify_tag,
     action_choice_json_schema,
+    action_choice_json_schema_untrimmed,
     aid_choice_json_schema,
     character_choice_json_schema,
     tag_choice_json_schema,
     treatment_choice_json_schema,
     treatment_choice_from_list_json_schema,
-    detailed_unstructured_treatment_action_text,
-    detailed_unstructured_tagging_action_text
+
 )

 log = logging.getLogger(__name__)
@@ -55,15 +54,19 @@

 class OutlinesTransformersADM(ActionBasedADM):
-    def __init__(self,
-                 model_name,
-                 device='auto',
-                 baseline=False,
-                 sampler=MultinomialSampler(),
-                 **kwargs):
+    def __init__(
+        self,
+        model_name,
+        device='auto',
+        baseline=False,
+        mode='eval',
+        sampler=MultinomialSampler(),
+        **kwargs
+    ):
         self.baseline = baseline
         model_kwargs = kwargs.get('model_kwargs', {})
+        self.mode = mode

         if 'precision' in kwargs:
             if kwargs['precision'] == 'half':
                 torch_dtype = torch.float16
@@ -87,6 +90,8 @@ def __init__(self,
         # the sampler itself (which defaults to 1); setting the number
         # of samples in the sampler may result in unexpected behavior
         self.sampler = sampler
+        # Edited prompt from the Demo interface
+        self._system_ui_prompt = None

     def dialog_to_prompt(self, dialog):
         tokenizer = self.model.tokenizer.tokenizer
@@ -152,25 +157,42 @@
             yield batch

     @classmethod
-    def run_in_batches(cls, inference_function, inputs, batch_size):
+    def run_in_batches(cls, inference_function, inputs, batch_size,
+                       max_tokens=None, seed=None):
         ''' Batch inference to avoid out of memory error'''
+        # torch.cuda.manual_seed() returns None rather than a generator;
+        # build an explicitly seeded torch.Generator so outlines actually
+        # consumes the seed (passed as None when no seed is configured)
+        rng = None
+        if seed is not None:
+            rng = torch.Generator(device="cuda")
+            rng.manual_seed(seed)
         outputs = []
         for batch in cls.batched(inputs, batch_size):
-            output = inference_function(list(batch))
+            output = inference_function(
+                list(batch),
+                max_tokens=max_tokens,
+                rng=rng
+            )
             if not isinstance(output, list):
                 output = [output]

             outputs.extend(output)

         return outputs

-    def top_level_choose_action(self,
-                                scenario_state,
-                                available_actions,
-                                alignment_target,
-                                num_positive_samples=1,
-                                num_negative_samples=0,
-                                generator_batch_size=5,
-                                kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
-                                **kwargs):
+    @property
+    def system_ui_prompt(self) -> str:
+        return self._system_ui_prompt
+
+    @system_ui_prompt.setter
+    def system_ui_prompt(self, edited_system_prompt: str):
+        self._system_ui_prompt = edited_system_prompt
+
+    def get_dialog_texts(self, scenario_state,
+                         available_actions,
+                         alignment_target,
+                         num_positive_samples=1,
+                         num_negative_samples=0,
+                         kdma_descriptions_map="align_system/prompt_engineering/kdma_descriptions.yml",
+                         **kwargs):
         if self.baseline and num_negative_samples > 0:
             raise RuntimeError("No notion of negative samples for baseline run")
         if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
@@ -185,6 +200,8 @@
             available_actions,
             scenario_state
         )
+        # Sort the choices
+        choices = sorted(choices)

         positive_icl_examples = []
         negative_icl_examples = []
@@ -224,8 +241,11 @@
                 # Create positive ICL example generators
                 positive_target = {'kdma': kdma, 'name': name, 'value': value}
-                positive_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
-                                                                                                   [positive_target])
+                positive_icl_example_generator = (
+                    incontext_utils.BaselineIncontextExampleGenerator(
+                        incontext_settings, [positive_target]
+                    )
+                )
                 # Get subset of relevant of examples
                 positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples(
                     sys_kdma_name=kdma,
@@ -244,8 +264,11 @@
                 if num_negative_samples > 0:
                     # Create negative ICL example generators
                     negative_target = {'kdma': kdma, 'name': name, 'value': negative_value}
-                    negative_icl_example_generator = incontext_utils.BaselineIncontextExampleGenerator(incontext_settings,
-                                                                                                       [negative_target])
+                    negative_icl_example_generator = (
+                        incontext_utils.BaselineIncontextExampleGenerator(
+                            incontext_settings, [negative_target]
+                        )
+                    )
                     # Get subset of relevant of examples
                     negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples(
                         sys_kdma_name=kdma,
@@ -265,10 +288,13 @@
             if "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
                 raise RuntimeError("No notion of incontext examples for baseline run")

+        shuffled_choices = choices
         positive_dialogs = []
         for _ in range(num_positive_samples):
-            shuffled_choices = random.sample(choices, len(choices))
-
+            # demo_kwargs may be absent in eval mode; default to the old
+            # behavior of always shuffling
+            if kwargs.get("demo_kwargs", {}).get("shuffle_choices", True):
+                shuffled_choices = random.sample(choices, len(choices))
+            if self.system_ui_prompt is not None and self.mode == "demo":
+                positive_system_prompt = self.system_ui_prompt
             prompt = action_selection_prompt(scenario_description, shuffled_choices)
             dialog = [{'role': 'system', 'content': positive_system_prompt}]
             dialog.extend(positive_icl_examples)
@@ -278,8 +304,8 @@

         negative_dialogs = []
         for _ in range(num_negative_samples):
-            shuffled_choices = random.sample(choices, len(choices))
-
+            if kwargs.get("demo_kwargs", {}).get("shuffle_choices", True):
+                shuffled_choices = random.sample(choices, len(choices))
             prompt = action_selection_prompt(scenario_description, shuffled_choices)
             dialog = [{'role': 'system', 'content': negative_system_prompt}]
             dialog.extend(negative_icl_examples)
@@ -290,20 +316,66 @@
         # Need to set the whitespace_pattern to prevent the state
         # machine from looping indefinitely in some cases, see:
         # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
-        generator = outlines.generate.json(
-            self.model,
-            action_choice_json_schema(json.dumps(choices)),
-            sampler=self.sampler,
-            whitespace_pattern=r"[ ]?")
-
-        dialog_texts = [self.dialog_to_prompt(d) for d in
-                        itertools.chain(positive_dialogs, negative_dialogs)]
+        dialog_texts = [
+            self.dialog_to_prompt(d)
+            for d in itertools.chain(positive_dialogs, negative_dialogs)
+        ]

-        log.info("[bold]*DIALOG PROMPT*[/bold]",
-                 extra={"markup": True})
+        log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True})
         log.info(dialog_texts[0])

-        responses = self.run_in_batches(generator, dialog_texts, generator_batch_size)
+        return dialog_texts, positive_dialogs
+
+    def top_level_choose_action(self,
+                                scenario_state,
+                                available_actions,
+                                alignment_target,
+                                num_positive_samples=1,
+                                num_negative_samples=0,
+                                generator_batch_size=5,
+                                kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
+                                **kwargs):
+        if 'demo_kwargs' not in kwargs and self.mode == "demo":
+            raise ValueError('Demo configuration missing')
+
+        # Absent in eval mode, so fall back to an empty dict
+        demo_kwargs = kwargs.get("demo_kwargs", {})
+        choices = adm_utils.format_choices(
+            [a.unstructured for a in available_actions],
+            available_actions,
+            scenario_state,
+        )
+
+        dialog_texts, positive_dialogs = self.get_dialog_texts(
+            scenario_state,
+            available_actions,
+            alignment_target,
+            num_positive_samples=num_positive_samples,
+            num_negative_samples=num_negative_samples,
+            generator_batch_size=generator_batch_size,
+            kdma_descriptions_map=kdma_descriptions_map,
+            **kwargs
+        )
+
+        if self.mode == "eval":
+            generator = outlines.generate.json(
+                self.model,
+                action_choice_json_schema(json.dumps(choices)),
+                sampler=self.sampler,
+                whitespace_pattern=r"[ ]?")
+        else:
+            generator = outlines.generate.json(
+                self.model,
+                action_choice_json_schema_untrimmed(json.dumps(choices)),
+                sampler=self.sampler,
+                whitespace_pattern=r"[ ]?",
+            )
+
+        responses = self.run_in_batches(
+            inference_function=generator,
+            inputs=dialog_texts,
+            batch_size=generator_batch_size,
+            max_tokens=int(demo_kwargs["max_generator_tokens"]) if demo_kwargs else None,
+            seed=int(demo_kwargs["generator_seed"]) if demo_kwargs else None,
+        )

         positive_responses_choices =\
             [r['action_choice'] for r in responses[0:num_positive_samples]]
diff --git a/align_system/prompt_engineering/outlines_prompts.py b/align_system/prompt_engineering/outlines_prompts.py
index df112f45..69b48e61 100644
--- a/align_system/prompt_engineering/outlines_prompts.py
+++ b/align_system/prompt_engineering/outlines_prompts.py
@@ -347,6 +347,21 @@ def action_choice_json_schema(choices_json_str):
     '''


+@outlines.prompt
+def action_choice_json_schema_untrimmed(choices_json_str):
+    '''
+    {"$defs": {"ActionChoice": {"enum": {{ choices_json_str }},
+                                "title": "ActionChoice",
+                                "type": "string"}},
+     "properties": {"detailed_reasoning": {"title": "Detailed Reasoning",
+                                           "type": "string", "minLength": 1},
+                    "action_choice": {"$ref": "#/$defs/ActionChoice"}},
+     "required": ["detailed_reasoning", "action_choice"],
+     "title": "ActionSelection",
+     "type": "object"}
+    '''
+
+
 @outlines.prompt
 def character_choice_json_schema(choices_json_str):
     '''
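
Reviewer note (illustration only, not part of the patch): a minimal sketch of how the demo-mode pieces above fit together. The model name, the load_scenario helper, and the demo_kwargs values are placeholders for this sketch; only mode, system_ui_prompt, and the demo_kwargs keys come from the patch itself.

from align_system.algorithms.outlines_adm import OutlinesTransformersADM

# Hypothetical helper standing in for the usual ADM driver/harness
scenario_state, available_actions, alignment_target = load_scenario("example")

adm = OutlinesTransformersADM(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model
    mode="demo",
)

# A prompt edited in the Demo UI overrides the built-in system prompt
adm.system_ui_prompt = "You are assisting a medical triage decision-maker..."

action = adm.top_level_choose_action(
    scenario_state,
    available_actions,
    alignment_target,
    demo_kwargs={
        "shuffle_choices": False,      # keep the UI's choice ordering stable
        "max_generator_tokens": 2048,  # cap tokens per generation
        "generator_seed": 42,          # reproducible sampling
    },
)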
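
A second sketch shows what the new untrimmed schema enforces. The choice list and responses are made up, and jsonschema is used here purely for demonstration; the patch itself relies on outlines' schema-guided generation rather than post-hoc validation.

import jsonschema

# The schema from action_choice_json_schema_untrimmed, rendered for a
# made-up choice list
choices = ["Apply tourniquet", "Check airway"]
schema = {
    "$defs": {"ActionChoice": {"enum": choices,
                               "title": "ActionChoice",
                               "type": "string"}},
    "properties": {"detailed_reasoning": {"title": "Detailed Reasoning",
                                          "type": "string", "minLength": 1},
                   "action_choice": {"$ref": "#/$defs/ActionChoice"}},
    "required": ["detailed_reasoning", "action_choice"],
    "title": "ActionSelection",
    "type": "object",
}

# A valid response carries non-empty reasoning plus one of the enum choices
response = {
    "detailed_reasoning": "Airway compromise takes priority here.",
    "action_choice": "Check airway",
}
jsonschema.validate(response, schema)  # raises ValidationError if invalid

# An empty reasoning string violates the minLength constraint
try:
    jsonschema.validate({"detailed_reasoning": "",
                         "action_choice": "Check airway"}, schema)
except jsonschema.ValidationError:
    print("empty detailed_reasoning rejected, as expected")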