Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions algorithms/delphi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer
import torch
import os

class Delphi:
def __init__(self, device="cuda", model="large"): # device parameter instead of device_id
self.device = torch.device(device) # use the provided device
print(f"Delphi device: {self.device}", model)

if model == "large":
self.MODEL_LOCATION = "../models/delphi-large"
self.MODEL_BASE = "t5-large"

elif model == "11b":
self.MODEL_LOCATION = "../models/delphi-11b"
self.MODEL_BASE = "t5-11b"

else:
raise ValueError("Model should be either 'large' or '11b'")

# verify that the model exists
if not os.path.exists(self.MODEL_LOCATION):
raise ValueError("Model file does not exist: {}".format(self.MODEL_LOCATION))

self.load_model()

def load_model(self):
self.model = T5ForConditionalGeneration.from_pretrained(self.MODEL_LOCATION)
self.model.to(self.device)
self.tokenizer = T5Tokenizer.from_pretrained(self.MODEL_BASE, model_max_length=512)

def run_inference(self, input_string):
input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids
outputs = self.model.generate(input_ids, max_length=200)

decoded_outputs = self.tokenizer.decode(outputs[0])

return decoded_outputs
6 changes: 6 additions & 0 deletions baseline_system_local_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

from algorithms.llm_baseline import LLMBaseline
from algorithms.llama_index import LlamaIndex
from algorithms.delphi import Delphi
from utils.enums import ProbeType
from prompt_engineering.common import build_casualties_string, prepare_prompt
from similarity_measures.bert import force_choice_with_bert




def main():
parser = argparse.ArgumentParser(
description="Simple LLM baseline system running against local files")
Expand Down Expand Up @@ -95,6 +98,9 @@ def run_baseline_system_local_filepath(
algorithm = LlamaIndex(
device="cuda", model_name=model,
**algorithm_kwargs_parsed)
elif algorithm == "delphi":
algorithm = Delphi(
device="cuda", model=model)

algorithm.load_model()

Expand Down