diff --git a/algorithms/delphi.py b/algorithms/delphi.py
new file mode 100644
index 00000000..21c247f4
--- /dev/null
+++ b/algorithms/delphi.py
@@ -0,0 +1,38 @@
+from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer
+import torch
+import os
+
+class Delphi:
+    def __init__(self, device="cuda", model="large"):  # device parameter instead of device_id
+        self.device = torch.device(device)  # use the provided device
+        print(f"Delphi device: {self.device}", model)
+
+        if model == "large":
+            self.MODEL_LOCATION = "../models/delphi-large"
+            self.MODEL_BASE = "t5-large"
+
+        elif model == "11b":
+            self.MODEL_LOCATION = "../models/delphi-11b"
+            self.MODEL_BASE = "t5-11b"
+
+        else:
+            raise ValueError("Model should be either 'large' or '11b'")
+
+        # verify that the model exists
+        if not os.path.exists(self.MODEL_LOCATION):
+            raise ValueError("Model file does not exist: {}".format(self.MODEL_LOCATION))
+
+        self.load_model()
+
+    def load_model(self):
+        self.model = T5ForConditionalGeneration.from_pretrained(self.MODEL_LOCATION)
+        self.model.to(self.device)
+        self.tokenizer = T5Tokenizer.from_pretrained(self.MODEL_BASE, model_max_length=512)
+
+    def run_inference(self, input_string):
+        input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids
+        outputs = self.model.generate(input_ids, max_length=200)
+
+        decoded_outputs = self.tokenizer.decode(outputs[0])
+
+        return decoded_outputs
\ No newline at end of file
diff --git a/baseline_system_local_files.py b/baseline_system_local_files.py
index ec55a3db..93ed58b5 100644
--- a/baseline_system_local_files.py
+++ b/baseline_system_local_files.py
@@ -4,11 +4,14 @@
 from algorithms.llm_baseline import LLMBaseline
 from algorithms.llama_index import LlamaIndex
+from algorithms.delphi import Delphi
 from utils.enums import ProbeType
 from prompt_engineering.common import build_casualties_string, prepare_prompt
 from similarity_measures.bert import force_choice_with_bert
+
+
 
 def main():
     parser = argparse.ArgumentParser(
         description="Simple LLM baseline system running against local files")
@@ -95,6 +98,9 @@
         algorithm = LlamaIndex(
             device="cuda", model_name=model, **algorithm_kwargs_parsed)
+    elif algorithm == "delphi":
+        algorithm = Delphi(
+            device="cuda", model=model)
 
     algorithm.load_model()
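
For reference, a minimal sketch of how the new wrapper would be driven (the prompt string here is hypothetical; the class expects a local checkpoint under ../models/ relative to the working directory, and because __init__ already calls load_model(), the baseline's subsequent algorithm.load_model() call loads the weights a second time):

    from algorithms.delphi import Delphi

    # Assumes a checkpoint directory at ../models/delphi-large and a CUDA device.
    delphi = Delphi(device="cuda", model="large")

    # run_inference returns the raw decoded sequence; special tokens such as
    # <pad> and </s> are kept, since decode() is called without
    # skip_special_tokens=True. Passing skip_special_tokens=True to
    # tokenizer.decode would strip them if that is the desired behavior.
    result = delphi.run_inference("lying to a friend to protect their feelings")
    print(result)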