diff --git a/align_system/algorithms/llama_2_single_kdma_adm.py b/align_system/algorithms/llama_2_single_kdma_adm.py
index c3fdc12d..926b8d06 100644
--- a/align_system/algorithms/llama_2_single_kdma_adm.py
+++ b/align_system/algorithms/llama_2_single_kdma_adm.py
@@ -113,6 +113,7 @@ def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', prec
         self.hf_model = hf_model
         self.temperature = temperature
         self.chat_template = kwargs.get('chat_template', None)
+        self.dataset = []
 
         assert precision in ['full', 'half'], "precision must be either 'full' or 'half'."
         self.precision = torch.float32 if precision == 'full' else torch.float16
@@ -124,11 +125,11 @@ def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', prec
     def load_model(self, model=None, tokenizer=None):
         assert (model is None) == (tokenizer is None), "model and tokenizer must both be None or both be not None."
         if model is not None:
-            print('Loading model and tokenizer from provided objects.')
+            log.info('Loading model and tokenizer from provided objects.')
             self.model = model
             self.tokenizer = tokenizer
         else:
-            print('Loading model:', self.hf_model)
+            log.info('Loading model: %s', self.hf_model)
             if self.device == 'auto':
                 self.model = AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision, device_map='auto')
             else:
@@ -282,7 +283,7 @@ def respond_to_dialog(self, dialog, prefix=None):
             else:
                 new_dialog.append(message)
             dialog = new_dialog
-        print('INPUT\n', dialog)
+        log.info('INPUT\n %s', dialog)
 
         prompt_tokens = [self.tokenizer.apply_chat_template(dialog, tokenize=True)]
         inference_pair['input'] = self.tokenizer.apply_chat_template(dialog, tokenize=False)
@@ -298,11 +299,11 @@ def respond_to_dialog(self, dialog, prefix=None):
         outputs = self.model.generate(prompt_tokens,
                                       return_dict_in_generate=True,
                                       output_scores=True,
                                       max_new_tokens=512,
                                       temperature=self.temperature, do_sample=True)
-        # Print the generated model output
+        # Log the generated model output
         generated_output = self.tokenizer.decode(outputs.sequences[0][prompt_length:])
         inference_pair['output'] = generated_output
-        print('INFERENCE PAIR\n', inference_pair)
+        log.info('INFERENCE PAIR\n %s', inference_pair)
 
         return generated_output, inference_pair
@@ -402,6 +403,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam
             shuffled_choices,
             system_message=system_message)
+
 
         if not logged_aligned_dialog:
             log.debug("[bold]*ALIGNED DIALOG*[/bold]", extra={"markup": True})
 
@@ -422,7 +424,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam
             if not good_parse:
                 reasoning, answer_idx, parse_method = Llama2SingleKDMAADM.bert_similarity_parse(high_response, shuffled_choices)
 
-            print('CHOSEN ANSWER IDX', answer_idx, shuffled_choices)
+            log.explain('CHOSEN ANSWER IDX %s %s', answer_idx, shuffled_choices)
             assert answer_idx is not None, f'Failed to parse answer index from generated output: {low_response}'
 
             responses.append({
@@ -594,10 +596,10 @@ def parse_generated_output(generated_output, n_choices):
 
     @staticmethod
     def bert_similarity_parse(generated_output, choices):
-        print('BERT SIMILARITY PARSE')
+        log.info('BERT SIMILARITY PARSE')
         force_choice_func = build_force_choice_func('bert')
         answer_idx, _ = force_choice_func(generated_output, choices)
-        print('ANSWER IDX', answer_idx, type(answer_idx))
+        log.info('ANSWER IDX %s %s', answer_idx, type(answer_idx))
         return generated_output, answer_idx, 'bert_similarity'
 
     @staticmethod
@@ -749,12 +751,59 @@ def run_aligned_decision_maker_with_voting(
             break
 
         return reasoning, answer_idx, responses, inference_pairs
 
+    def format_single_incontext_prompt(self, sample, labels):
+        """Format one in-context example, marking each choice with its high/low attribute level."""
+        prompt = sample['scenario']
+        if sample['state'] is not None:
+            prompt += f'\n{sample["state"]}'
+
+        for choice, label in zip(sample['choices'], labels):
+            # Treat a KDMA value above 5 as 'high' and anything else as 'low'
+            level = 'high' if list(label.values())[0] > 5 else 'low'
+            attribute = list(label.keys())[0].replace('_', ' ')
+            prompt += f' If you had a {level} {attribute}, you would select {choice}.'
+
+        return prompt
+
+    # TODO: add prompt completion here for choices as well.
+    def __call__(self, sample, target_kdma_values, **kwargs):
+        """Build the prompt and send it to the LLM to ask for a single KDMA."""
         prompt = sample['scenario']
         if sample['state'] is not None:
             prompt += f'\n{sample["state"]}'
 
+        if 'incontext' in kwargs:
+            possible_samples = []
+
+            # Each dataset entry holds the sample info in its first element and
+            # the labels in its second; skip the probe currently being answered.
+            for sam in kwargs['dataset']:
+                if sam[0]['probe_id'] != sample['probe_id']:
+                    possible_samples.append(sam)
+
+            if len(possible_samples) < kwargs['incontext']['number']:
+                raise RuntimeError(f'Not enough possible incontext samples to learn from here. '
+                                   f'Only {len(possible_samples)} samples while asking for '
+                                   f'{kwargs["incontext"]["number"]} in-context samples')
+
+            if kwargs['incontext']['method'] == 'random':
+                chosen_samples = random.sample(possible_samples, kwargs['incontext']['number'])
+            else:
+                raise ValueError(f'"{kwargs["incontext"]["method"]}" is not a valid incontext method. '
+                                 f'Please use "random".')
+
+            incontext_prompt_start = ' Here are some examples of similar problems with their attributes. '
+
+            extra_prompts = [incontext_prompt_start]
+            for ci, (cs, cl) in enumerate(chosen_samples, start=1):
+                extra_prompts.append(f' Example {ci}: ' + self.format_single_incontext_prompt(cs, cl))
+
+            extra_prompts.append(' Given these similar examples, please answer the question for the following scenario. ')
+
+            prompt = ''.join(extra_prompts) + prompt
+
         if 'retriever' in kwargs:
             # retriever_prompt = "How would you treat the following injuries: {}".format(prompt)
             retriever_prompt = "{} {}".format(prompt, sample['probe'])
 
diff --git a/align_system/evaluation/adm_evaluator.py b/align_system/evaluation/adm_evaluator.py
index ff7eda9f..4a742e4f 100644
--- a/align_system/evaluation/adm_evaluator.py
+++ b/align_system/evaluation/adm_evaluator.py
@@ -11,7 +11,7 @@ def generate_outputs(dataset, adm, target_kdma_values, **kwargs):
             })
             continue
 
-        outputs.append(adm(input_, target_kdma_values, labels=label, **kwargs))
+        outputs.append(adm(input_, target_kdma_values, labels=label, dataset=dataset, **kwargs))
 
     return outputs
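
A minimal usage sketch (not part of the patch) of how the new `incontext` option could be driven through `generate_outputs`. The `dataset`, `adm`, and KDMA name below are hypothetical placeholders; the `incontext` dict shape (`method`, `number`) and the `(sample, labels)` pairing of dataset entries follow the code above.

```python
# Sketch only, assuming `dataset` (a list of (sample, labels) pairs) and a
# constructed ADM instance already exist; 'maximization' is a placeholder
# KDMA name, not something this patch defines.
from align_system.evaluation.adm_evaluator import generate_outputs

outputs = generate_outputs(
    dataset,                        # forwarded to the ADM via dataset=dataset
    adm,                            # e.g. a loaded Llama2SingleKDMAADM
    {'maximization': 10},           # target_kdma_values
    incontext={'method': 'random',  # 'random' is the only supported method
               'number': 2},        # few-shot examples prepended to the prompt
)
```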