5 changes: 3 additions & 2 deletions align_system/algorithms/abstracts.py
@@ -16,11 +16,12 @@ def choose_action(self,

class StructuredInferenceEngine(ABC):
@abstractmethod
def dialog_to_prompt(dialog: list[dict]) -> str:
def dialog_to_prompt(self, dialog: list[dict]) -> str:
pass

@abstractmethod
def run_inference(prompts: Union[str, list[str]],
def run_inference(self,
prompts: Union[str, list[str]],
schema: str) -> Union[dict, list[dict]]:
pass

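For reference, a minimal sketch of a concrete engine satisfying the corrected StructuredInferenceEngine interface; the FakeEngine class and its echo behavior are illustrative assumptions, not part of this change:

```python
from typing import Union

from align_system.algorithms.abstracts import StructuredInferenceEngine


class FakeEngine(StructuredInferenceEngine):
    # Illustrative only: real engines render dialogs with a chat template
    def dialog_to_prompt(self, dialog: list[dict]) -> str:
        return "\n".join(f"{turn['role']}: {turn['content']}" for turn in dialog)

    # Illustrative only: real engines run schema-constrained generation
    def run_inference(self,
                      prompts: Union[str, list[str]],
                      schema: str) -> Union[dict, list[dict]]:
        if isinstance(prompts, str):
            return {"prompt": prompts, "schema": schema}
        return [{"prompt": p, "schema": schema} for p in prompts]


engine = FakeEngine()
print(engine.dialog_to_prompt([{"role": "user", "content": "hello"}]))
```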
105 changes: 61 additions & 44 deletions align_system/algorithms/outlines_adm.py
@@ -7,7 +7,7 @@
from functools import partial

import outlines
from outlines.samplers import MultinomialSampler
from outlines.types import JsonSchema
import jinja2
from rich.highlighter import JSONHighlighter
from align_system.data_models.compat.ta3_ph1_client_models import (
@@ -16,6 +16,7 @@
CharacterTagEnum,
KDMAValue
)
import transformers

from align_system.utils import logging
from align_system.utils import adm_utils
@@ -67,7 +68,7 @@ def __init__(self,
model_name,
device='auto',
baseline=False,
sampler=MultinomialSampler(),
generation_kwargs=None,
scenario_description_template=scenario_state_description_1,
action_selection_prompt_template=action_selection_prompt,
baseline_system_prompt=baseline_system_prompt,
@@ -86,19 +87,21 @@ def __init__(self,
f"Unexpected value for 'precision' ({kwargs['precision']})"
", expecting either 'half' or 'full'")

model_kwargs['torch_dtype'] = torch_dtype
model_kwargs['dtype'] = torch_dtype

self.model = outlines.models.transformers(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=kwargs.get('tokenizer_kwargs', {}))
# NOTE: In cases where we want multiple samples, we're passing
# in a list of prompts (this allows us to shuffle answers in
# each prompt), rather than setting the number of samples in
# the sampler itself (which defaults to 1); setting the number
# of samples in the sampler may result in unexpected behavior
self.sampler = sampler
self.model = outlines.from_transformers(
transformers.AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs, device_map=device),
transformers.AutoTokenizer.from_pretrained(model_name, **kwargs.get('tokenizer_kwargs', {})),
device_dtype=torch_dtype)

if generation_kwargs is None:
generation_kwargs = {'temperature': 0.7}
self.generation_kwargs = generation_kwargs

# Sometimes the internal default for outlines/transformers is 20,
# leading to very short (and often invalid JSON) outputs. Setting a
# somewhat generous default.
self.generation_kwargs.setdefault('max_new_tokens', 8192)

self.outlines_seed = outlines_seed
if self.outlines_seed is None:
@@ -240,15 +243,11 @@ def batched(cls, iterable, n):
yield batch

@classmethod
def run_in_batches(cls, inference_function, inputs, batch_size, rng=None):
def run_in_batches(cls, inference_function, inputs, batch_size, **generation_kwargs):
''' Batch inference to avoid out of memory error'''
outputs = []
for batch in cls.batched(inputs, batch_size):
if rng is None:
output = inference_function(list(batch))
else:
output = inference_function(list(batch), rng=rng)

output = inference_function(list(batch), **generation_kwargs)
if not isinstance(output, list):
output = [output]
outputs.extend(output)
@@ -432,12 +431,14 @@ def top_level_choose_action(self,
# Need to set the whitespace_pattern to prevent the state
# machine from looping indefinitely in some cases, see:
# https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
action_choice_json_schema(json.dumps(choices), reasoning_max_length),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

if max_generator_tokens >= 0:
generator = partial(generator, max_tokens=max_generator_tokens)

@@ -454,7 +455,13 @@ def top_level_choose_action(self,
extra={"markup": True})
log.info(dialog_texts[0])

responses = self.run_in_batches(generator, dialog_texts, generator_batch_size, rng=self.outlines_rng)
responses = self.run_in_batches(generator.batch,
dialog_texts,
generator_batch_size,
rng=self.outlines_rng,
**self.generation_kwargs)
responses = [json.loads(r) for r in responses]

positive_responses_choices =\
[r['action_choice'] for r in
responses[0:num_positive_samples]]
@@ -657,17 +664,19 @@ def ensure_character_id_is_populated(self,

character_names = [c.name for c in characters]

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
character_choice_json_schema(json.dumps(character_names)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_character = generator(dialog_text)
selected_character = json.loads(generator(dialog_text, **self.generation_kwargs))
selected_character_idx = character_names.index(selected_character['character_choice'])

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
@@ -727,19 +736,21 @@ def populate_treatment_parameters(self,

dialog_text = self.dialog_to_prompt(dialog)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
treatment_choice_json_schema(
json.dumps([s.type for s in available_supplies]),
json.dumps(valid_treatment_locations)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_treatment = generator(dialog_text)
selected_treatment = json.loads(generator(dialog_text, **self.generation_kwargs))

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
@@ -799,14 +810,16 @@ def select_treatment_parameters(self,
extra={"markup": True})
log.info(dialog_text)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
treatment_choice_from_list_json_schema(
json.dumps(possible_treatments)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

selected_treatment = generator(dialog_text)
generator = outlines.Generator(
self.model,
json_schema)

selected_treatment = json.loads(generator(dialog_text, **self.generation_kwargs))
log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
log.info(selected_treatment, extra={"highlighter": JSON_HIGHLIGHTER})
@@ -843,18 +856,20 @@ def populate_tagging_parameters(self,

dialog_text = self.dialog_to_prompt(dialog)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
tag_choice_json_schema(
json.dumps(valid_tags)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_tag = generator(dialog_text)
selected_tag = json.loads(generator(dialog_text, **self.generation_kwargs))

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
@@ -906,18 +921,20 @@ def populate_aid_parameters(self,

dialog_text = self.dialog_to_prompt(dialog)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
aid_choice_json_schema(
json.dumps([aid.id for aid in available_aids])),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_aid = generator(dialog_text)
selected_aid = json.loads(generator(dialog_text, **self.generation_kwargs))

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
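For reference, a minimal standalone sketch of the constrained-JSON pattern the ADM now uses (JsonSchema plus outlines.Generator, with json.loads on the returned text), assuming outlines >= 1.0 and the call signatures shown in this diff; the model name and schema below are placeholders:

```python
import json

import outlines
import transformers
from outlines.types import JsonSchema

# Placeholder model; the ADM loads its configured model_name the same way
model = outlines.from_transformers(
    transformers.AutoModelForCausalLM.from_pretrained("gpt2"),
    transformers.AutoTokenizer.from_pretrained("gpt2"))

# whitespace_pattern keeps the constrained decoder from looping on
# unbounded whitespace (same workaround as in the diff above)
schema = JsonSchema(
    '{"type": "object",'
    ' "properties": {"action_choice": {"type": "string"}},'
    ' "required": ["action_choice"]}',
    whitespace_pattern=r"[ ]?")

generator = outlines.Generator(model, schema)

# The new Generator returns JSON text rather than a parsed dict, hence the
# json.loads calls added throughout this file
raw = generator("Pick an action and answer in JSON: ", max_new_tokens=256)
print(json.loads(raw))
```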
80 changes: 45 additions & 35 deletions align_system/algorithms/outlines_inference_engine.py
@@ -1,30 +1,46 @@
import itertools
from collections.abc import Iterable
from textwrap import dedent
import json

import transformers
import outlines
from outlines.samplers import MultinomialSampler
from outlines.types import JsonSchema
import jinja2
import torch

from align_system.algorithms.abstracts import StructuredInferenceEngine

# Sometimes the internal default for outlines/transformers is 20,
# leading to very short (and often invalid JSON) outputs. Setting a
# somewhat generous default.
DEFAULT_MAX_GENERATOR_TOKENS=8192

class OutlinesTransformersInferenceEngine(StructuredInferenceEngine):
def __init__(self,
model_name,
device='auto',
precision='full',
max_generator_tokens=None,
sampler=MultinomialSampler(),
max_generator_tokens=DEFAULT_MAX_GENERATOR_TOKENS,
inference_batch_size=5,
model_kwargs={},
tokenizer_kwargs={}):
generation_kwargs=None,
model_kwargs=None,
tokenizer_kwargs=None):
self.model_name = model_name
self.precision = precision
self.inference_batch_size = inference_batch_size

if model_kwargs is None:
model_kwargs = {}
self.model_kwargs = model_kwargs

if tokenizer_kwargs is None:
tokenizer_kwargs = {}
self.tokenizer_kwargs = tokenizer_kwargs

if generation_kwargs is None:
generation_kwargs = {}
self.generation_kwargs = generation_kwargs

self.max_generator_tokens = max_generator_tokens

if self.precision == 'half':
@@ -36,19 +52,12 @@ def __init__(self,
f"Unexpected value for 'precision' ({precision})"
", expecting either 'half' or 'full'")

self.model_kwargs['torch_dtype'] = torch_dtype
self.model_kwargs['dtype'] = torch_dtype

self.model = outlines.models.transformers(
self.model_name,
device=device,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs)
# NOTE: In cases where we want multiple samples, we're passing
# in a list of prompts (this allows us to shuffle answers in
# each prompt), rather than setting the number of samples in
# the sampler itself (which defaults to 1); setting the number
# of samples in the sampler may result in unexpected behavior
self.sampler = sampler
self.model = outlines.from_transformers(
transformers.AutoModelForCausalLM.from_pretrained(model_name, **self.model_kwargs, device_map='auto'),
transformers.AutoTokenizer.from_pretrained(model_name, **self.tokenizer_kwargs),
device_dtype=torch_dtype)

def dialog_to_prompt(self, dialog):
tokenizer = self.model.tokenizer.tokenizer
@@ -85,29 +94,36 @@ def batched(cls, iterable, n):
yield batch

@classmethod
def run_in_batches(cls, inference_function, inputs, batch_size, max_generator_tokens=None):
def run_in_batches(cls,
inference_function,
inputs,
batch_size,
max_generator_tokens=DEFAULT_MAX_GENERATOR_TOKENS,
**generation_kwargs):
''' Batch inference to avoid out of memory error'''
outputs = []
for batch in cls.batched(inputs, batch_size):
output = inference_function(list(batch), max_tokens=max_generator_tokens)
output = inference_function(list(batch), max_new_tokens=max_generator_tokens, **generation_kwargs)
if not isinstance(output, list):
output = [output]
outputs.extend(output)
return outputs

def run_inference(self, prompts, schema):
generator = outlines.generate.json(
json_schema = JsonSchema(schema, whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
schema,
sampler=self.sampler,
whitespace_pattern=r"[ ]?")
json_schema)

if isinstance(prompts, str):
return generator(prompts, max_tokens=self.max_generator_tokens)
output = generator(prompts, max_new_tokens=self.max_generator_tokens, **self.generation_kwargs)
return json.loads(output)
elif isinstance(prompts, Iterable):
return self.run_in_batches(
generator, prompts, self.inference_batch_size, self.max_generator_tokens
output = self.run_in_batches(
generator.batch, prompts, self.inference_batch_size, self.max_generator_tokens, **self.generation_kwargs
)
return [json.loads(r) for r in output]
else:
raise TypeError("Don't know how to run inference on provided "
"`prompts` object")
@@ -116,7 +132,7 @@ def run_inference_unstructured(self, prompts):
generator = outlines.generate.regex(
self.model,
r'.*', # "allow anything" regex
sampler=self.sampler)
**self.generation_kwargs)

if isinstance(prompts, str):
return generator(prompts, self.max_generator_tokens)
@@ -135,18 +151,12 @@ def cache_repr(self):
object instances, it's assumed that inference output will be
the same
'''
def _sampler_repr(sampler):
return "{}.{}({})".format(
sampler.__class__.__module__,
sampler.__class__.__name__,
", ".join([f"{k}={v}" for k, v in vars(sampler).items()]))

return dedent(f"""
{self.__class__.__module__}.{self.__class__.__name__}(
model_name="{self.model_name}",
precision="{self.precision}",
sampler={_sampler_repr(self.sampler)},
inference_batch_size={self.inference_batch_size},
model_kwargs={self.model_kwargs},
tokenizer_kwargs={self.tokenizer_kwargs},
generation_kwargs={self.generation_kwargs},
)""").strip()