diff --git a/tinker_cookbook/recipes/rubric/README.md b/tinker_cookbook/recipes/rubric/README.md
new file mode 100644
index 0000000..92a3026
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/README.md
@@ -0,0 +1,94 @@
+# Rubric-based Grading for LLMs
+
+- [`data.py`](./data.py) contains the definition of the datapoint class. Each datapoint consists of a conversation prefix and a list of rubric items.
+- [`generate_data.py`](./generate_data.py) generates example datapoints if you want to run our demo on addition.
+- [`env.py`](./env.py) determines what each rollout does: the policy reads the prefix and generates a response, a grader LLM grades the response against each rubric item, and the reward is the average of the graders' scores.
+- [`train.py`](./train.py) lets you train LLMs on any dataset saved in our format (specified in `data.py`). The default configuration trains on the addition task, whose data is generated by `generate_data.py`.
+- [`prometheus_experimental.py`](./prometheus_experimental.py) contains a script to train LLMs on the rubrics from the [`prometheus-eval/Feedback-Collection`](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/viewer/default/train?row=0&views%5B%5D=train) dataset. It is experimental, though -- even though the reward goes up, there is no guarantee that the model is actually better. We hope the script serves as a starting point; more research is needed.
+
+
+## A simple example of using a grader LLM with rubrics
+
+We show how to use a grader LLM with a rubric to provide a reward for an addition task. For example:
+
+```
+**User**: What's 233 + 100?
+**Assistant**: 333
+```
+
+Usually, this could be graded by matching the number to the ground truth 333 without needing an LLM. However, for pedagogical purposes, we will grade the response using a language model with a rubric. That is, we will ask a language model: "Does the assistant answer 333?"
+
+### Generate an example dataset
+
+To run this, first generate a dataset:
+
+```
+python -m tinker_cookbook.recipes.rubric.generate_data
+```
+
+This generates two `jsonl` files, one for training and one for testing. If you look into `tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of
+- `convo`: the conversation prefix that the policy sees
+- `rubric_items`: a list of rubric items that specify what counts as a good response, how the grader should format its response, and how the grading result should be extracted
+
+```
+{
+  "convo": [
+    {
+      "role": "user",
+      "content": "What is 4 + 5?"
+    },
+    {
+      "role": "assistant",
+      "content": "9"
+    },
+    {
+      "role": "user",
+      "content": "What is 122 + 12?"
+    }
+  ],
+  "rubric_items": [
+    {
+      "rubric_str": "Does the chatbot correctly get the answer 134?",
+      "extraction_regex": "<score>(.*)</score>",
+      "grader_output_format_instruction": "Please output your score between 0 and 1 wrapped in <score> ... </score>"
+    }
+  ]
+}
+```
+
+### Debugging and printing what happens during rollouts
+
+Run
+```
+python -m tinker_cookbook.recipes.rubric.debug_env
+```
+
+This prints the messages that the policy sees, its response, the grader's input, and the grader's output.
+
+### An example training run
+
+To train the LLM on the addition task with rubric-based grading, run
+```
+python -m tinker_cookbook.recipes.rubric.train
+```
+
+You should see the reward go up quickly.
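+
+To train on your own task, write datapoints in the same JSONL format and point `train.py` at them via its `train_jsonl_path` / `test_jsonl_path` options. Below is a minimal sketch using the classes from `data.py`; the task, rubric wording, and output file name are placeholders:
+
+```
+from tinker_cookbook.recipes.rubric.data import Rubric, RubricBasedDatapoint
+
+datapoint = RubricBasedDatapoint(
+    convo=[{"role": "user", "content": "Summarize the plot of Hamlet in one sentence."}],
+    rubric_items=[
+        Rubric(rubric_str="Does the chatbot mention that Hamlet seeks revenge for his father's death?"),
+        Rubric(rubric_str="Is the summary a single sentence?"),
+    ],
+)
+
+with open("my_rubric_train.jsonl", "w") as f:
+    f.write(datapoint.to_json() + "\n")
+```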
+
+### A more realistic dataset
+
+We take the `prometheus-eval/Feedback-Collection` dataset from [Hugging Face](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubrics for grading general chat responses. Run the following to kick off training:
+
+```
+python -m tinker_cookbook.recipes.rubric.prometheus_experimental
+```
+
+The reward climbs up steadily.
+
+Note that this training recipe is experimental -- to improve performance further, we may need to fine-tune the grader LLM as well. We hope our code serves as a starting point for you to improve rubric-based grading for training LLMs!
diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py
new file mode 100644
index 0000000..c9011b6
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/data.py
@@ -0,0 +1,169 @@
+from tinker_cookbook.renderers import (
+    Message,
+    Role,
+)
+from typing import TypeAlias
+from dataclasses import dataclass
+from typing import Sequence
+import re
+import json
+import chz
+
+Conversation: TypeAlias = list[Message]
+
+
+@dataclass
+class Rubric:
+    """
+    A rubric should specify 1) what counts as a good response, 2) how the grader language model should output the score, and 3) how to extract the score from the grader's response.
+    """
+
+    rubric_str: str
+    extraction_regex: str = r"<score>(.*)</score>"
+    grader_output_format_instruction: str = (
+        "Please output your score between 0 and 1 wrapped in <score> ... </score>"
+    )
+
+    def __convert_role(self, role: Role) -> str:
+        return "Human" if role in ("user", "system") else "Chatbot"
+
+    def _flatten_convo(self, convo: Conversation) -> str:
+        """
+        Convert the whole conversation (user's turns + assistant's turns) into a single string. E.g.
+        \n\nHuman: ....
+        \n\nChatbot: ...
+        \n\nHuman: ...
+        \n\nChatbot: ...
+        """
+        return "\n\n".join(
+            [f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo]
+        )
+
+    def get_grader_prompt(self, convo: Conversation) -> Conversation:
+        """
+        Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric.
+        """
+
+        prompt = "I will show you 1) a conversation between a human and a chatbot, and 2) a rubric for grading the conversation. Please grade the conversation based on the rubric."
+
+        prompt += f" Here is the conversation: \n\n{self._flatten_convo(convo)} \n\n\n\nHere is the rubric: \n{self.rubric_str}\n\n"
+        prompt += f"Please grade the conversation based on the rubric. {self.grader_output_format_instruction}"
+        return [
+            {
+                "role": "user",
+                "content": prompt,
+            }
+        ]
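+
+    # Example of the extraction step (illustrative): with the default extraction_regex and
+    # format instruction above, a grader reply such as "<score>0.8</score>" is parsed to 0.8
+    # by extract_score below; a reply that cannot be parsed falls back to a score of 0.0.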
+    def extract_score(self, response: str) -> float:
+        match = re.search(self.extraction_regex, response, re.DOTALL)
+        if match is not None:
+            try:
+                return float(match.group(1))
+            except ValueError:
+                print(f"Warning: Failed to extract score from grader response: {response}")
+                return 0.0
+        else:
+            print(f"Warning: Failed to extract score from grader response: {response}")
+            return 0.0
+
+    def to_dict(self) -> dict[str, str]:
+        return {
+            "rubric_str": self.rubric_str,
+            "extraction_regex": self.extraction_regex,
+            "grader_output_format_instruction": self.grader_output_format_instruction,
+        }
+
+    def to_json(self) -> str:
+        return json.dumps(self.to_dict())
+
+    @staticmethod
+    def from_dict(d: dict[str, str]) -> "Rubric":
+        return Rubric(
+            rubric_str=d["rubric_str"],
+            extraction_regex=d["extraction_regex"],
+            grader_output_format_instruction=d["grader_output_format_instruction"],
+        )
+
+    @staticmethod
+    def from_json(json_str: str) -> "Rubric":
+        return Rubric.from_dict(json.loads(json_str))
+
+
+@dataclass(frozen=True)
+class RubricBasedDatapoint:
+    """
+    A rubric-based datapoint contains a conversation and a list of rubric items.
+    In this task, the policy model sees the conversation, creates a response, and then the grader language model grades the response based on the rubric.
+    """
+
+    convo: Conversation
+    rubric_items: Sequence[Rubric]
+
+    def to_json(self) -> str:
+        return json.dumps(
+            {
+                "convo": self.convo,
+                "rubric_items": [rubric.to_dict() for rubric in self.rubric_items],
+            }
+        )
+
+    @staticmethod
+    def from_json(json_str: str) -> "RubricBasedDatapoint":
+        d = json.loads(json_str)
+        return RubricBasedDatapoint(
+            convo=d["convo"],
+            rubric_items=[Rubric.from_dict(rubric) for rubric in d["rubric_items"]],
+        )
+
+
+@chz.chz
+class RubricDatapointListBuilder:
+    def __call__(self) -> Sequence[RubricBasedDatapoint]:
+        raise NotImplementedError("Subclass must implement this method")
+
+
+@chz.chz
+class RubricDatapointListBuilderFromJsonl(RubricDatapointListBuilder):
+    jsonl_path: str
+
+    def __call__(self) -> Sequence[RubricBasedDatapoint]:
+        datapoints = []
+        with open(self.jsonl_path, "r") as f:
+            for line in f:
+                datapoints.append(RubricBasedDatapoint.from_json(line))
+        return datapoints
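+
+
+# Builds datapoints from the Hugging Face `prometheus-eval/Feedback-Collection` dataset:
+# each item's grading criteria, per-score descriptions (1-5), and reference answer are folded
+# into a single Rubric, and the original instruction becomes the conversation prefix.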
", + ) + + return RubricBasedDatapoint( + convo=convo, + rubric_items=[rubric], + ) diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py new file mode 100644 index 0000000..84950bb --- /dev/null +++ b/tinker_cookbook/recipes/rubric/debug_env.py @@ -0,0 +1,74 @@ +from tinker_cookbook import model_info +from tinker_cookbook.recipes.rubric.env import RubricGradedEnv, RubricBasedDatapoint, Rubric +from tinker_cookbook.completers import TinkerMessageCompleter, TinkerTokenCompleter +from tinker_cookbook.renderers import get_renderer +from tinker_cookbook.tokenizer_utils import get_tokenizer +import tinker +from tinker_cookbook.rl.rollouts import do_single_rollout +import asyncio + + +def get_addition_datapoint() -> RubricBasedDatapoint: + datapoint = RubricBasedDatapoint( + convo=[ + {"role": "user", "content": "What is 4 + 5?"}, + {"role": "assistant", "content": "9"}, + {"role": "user", "content": "What is 125 + 311?"}, + ], + rubric_items=[ + Rubric(rubric_str="Does the chatbot correctly get the answer 436?"), + Rubric(rubric_str="Does the chatbot provide an answer without saying anything else?"), + ], + ) + + return datapoint + + +def get_prometheus_datapoint() -> RubricBasedDatapoint: + from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder + + datapoint = PrometheusDatapointListBuilder()() + datapoint = datapoint[0] + return datapoint + + +async def main(datapoint: RubricBasedDatapoint): + policy_name = "meta-llama/Llama-3.1-8B-Instruct" + grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" + service_client = tinker.ServiceClient() + policy = TinkerTokenCompleter( + sampling_client=service_client.create_sampling_client(base_model=policy_name), + max_tokens=64, + ) + policy_renderer = get_renderer( + model_info.get_recommended_renderer_name(policy_name), get_tokenizer(policy_name) + ) + grader = TinkerMessageCompleter( + sampling_client=service_client.create_sampling_client(base_model=grader_name), + renderer=get_renderer( + model_info.get_recommended_renderer_name(grader_name), get_tokenizer(grader_name) + ), + max_tokens=64, + ) + + env = RubricGradedEnv( + renderer=policy_renderer, + datapoint=datapoint, + grader_llm=grader, + debug=True, + ) + + await do_single_rollout(policy, env) + + +if __name__ == "__main__": + dataset = "addition" + + if dataset == "addition": + datapoint = get_addition_datapoint() + asyncio.run(main(datapoint)) + elif dataset == "prometheus": + datapoint = get_prometheus_datapoint() + asyncio.run(main(datapoint)) + else: + raise ValueError(f"Unknown dataset: {dataset}") diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py new file mode 100644 index 0000000..218f052 --- /dev/null +++ b/tinker_cookbook/recipes/rubric/env.py @@ -0,0 +1,214 @@ +from tinker_cookbook.rl.types import ( + Action, + Env, + StepResult, + EnvGroupBuilder, + RLDataset, + RLDatasetBuilder, +) +from tinker_cookbook.renderers import Renderer +from tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter +from tinker.types import ModelInput +from dataclasses import dataclass +from typing import Sequence +import json +import chz +import tinker +from tinker_cookbook.tokenizer_utils import get_tokenizer +from tinker_cookbook.renderers import get_renderer +import asyncio +from tinker_cookbook import model_info +from tinker_cookbook.recipes.rubric.data import ( + RubricBasedDatapoint, + Rubric, + Conversation, + RubricDatapointListBuilder, +) + +# ANSI color 
diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py
new file mode 100644
index 0000000..218f052
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/env.py
@@ -0,0 +1,214 @@
+from tinker_cookbook.rl.types import (
+    Action,
+    Env,
+    StepResult,
+    EnvGroupBuilder,
+    RLDataset,
+    RLDatasetBuilder,
+)
+from tinker_cookbook.renderers import Renderer
+from tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter
+from tinker.types import ModelInput
+from dataclasses import dataclass
+from typing import Sequence
+import json
+import chz
+import tinker
+from tinker_cookbook.tokenizer_utils import get_tokenizer
+from tinker_cookbook.renderers import get_renderer
+import asyncio
+from tinker_cookbook import model_info
+from tinker_cookbook.recipes.rubric.data import (
+    RubricBasedDatapoint,
+    Rubric,
+    Conversation,
+    RubricDatapointListBuilder,
+)
+
+# ANSI color codes
+BLUE = "\033[94m"
+GREEN = "\033[92m"
+YELLOW = "\033[93m"
+MAGENTA = "\033[95m"
+RESET = "\033[0m"
+
+
+class RubricGradedEnv(Env):
+    def __init__(
+        self,
+        renderer: Renderer,
+        datapoint: RubricBasedDatapoint,
+        grader_llm: MessageCompleter,
+        debug: bool = False,
+    ):
+        """
+        Initialize the RubricGradedEnv. In this environment, the policy model sees the conversation, creates a response, and then the grader language model grades the response based on the rubric.
+        """
+        self.renderer = renderer
+        self.datapoint = datapoint
+        self.grader_llm = grader_llm
+        self.debug = debug
+
+    @property
+    def rubric_items(self) -> Sequence[Rubric]:
+        return self.datapoint.rubric_items
+
+    @property
+    def convo(self) -> Conversation:
+        return self.datapoint.convo
+
+    @property
+    def stop_condition(self) -> StopCondition:
+        return self.renderer.get_stop_sequences()
+
+    async def initial_observation(self) -> tuple[ModelInput, StopCondition]:
+        return self.renderer.build_generation_prompt(self.convo), self.stop_condition
+
+    async def _grade_with_rubric(self, convo: Conversation, rubric: Rubric) -> float:
+        # this is the conversation for the grader
+        # effectively it's just one user turn
+        grader_prompt = rubric.get_grader_prompt(convo)
+
+        # obtain the response from the grader and convert it to a score
+        grader_response = await self.grader_llm(grader_prompt)
+        grader_response_content = grader_response["content"]
+        assert isinstance(grader_response_content, str), "Grader response content must be a string"
+        score = rubric.extract_score(grader_response_content)
+        if self.debug:
+            print(f"{YELLOW}{'=' * 80}")
+            print("DEBUG: First Turn of Grader Prompt")
+            print(f"{'=' * 80}{RESET}")
+            print(f"{YELLOW}{grader_prompt[0]['content']}{RESET}\n")
+
+            print(f"{MAGENTA}{'=' * 80}")
+            print("DEBUG: Score")
+            print(f"{'=' * 80}{RESET}")
+            print(f"{MAGENTA}Grader Response: {grader_response_content}{RESET}\n")
+            print(f"{MAGENTA}Extracted Score: {score}{RESET}\n")
+        return score
+
+    async def step(self, action: Action) -> StepResult:
+        # obtain the policy action message
+        (policy_action_message, _parse_success) = self.renderer.parse_response(action)
+
+        if self.debug:
+            print(f"\n{BLUE}{'=' * 80}")
+            print("DEBUG: Original Conversation (self.convo)")
+            print(f"{'=' * 80}{RESET}")
+            print(f"{BLUE}{json.dumps(self.convo, indent=2)}{RESET}\n")
+
+            print(f"{GREEN}{'=' * 80}")
+            print("DEBUG: Policy Action Message")
+            print(f"{'=' * 80}{RESET}")
+            print(f"{GREEN}{json.dumps(policy_action_message, indent=2)}{RESET}\n")
+        # this shows the full back-and-forth conversation to the grader
+        convo = self.convo + [policy_action_message]
+
+        scores = await asyncio.gather(
+            *[self._grade_with_rubric(convo, rubric_item) for rubric_item in self.rubric_items]
+        )
+        avg_score = sum(scores) / len(scores)
+
+        return StepResult(
+            reward=avg_score,
+            episode_done=True,
+            next_observation=self.renderer.build_generation_prompt(convo),
+            next_stop_condition=self.stop_condition,
+        )
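+
+
+# Builds `group_size` independent RubricGradedEnv instances that share the same datapoint and
+# grader, so the trainer can compare multiple rollouts of the same prompt.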
+@dataclass(frozen=True)
+class RubricGradedEnvGroupBuilder(EnvGroupBuilder):
+    renderer: Renderer
+    datapoint: RubricBasedDatapoint
+    grader_llm: MessageCompleter
+    group_size: int
+
+    async def make_envs(self) -> Sequence[RubricGradedEnv]:
+        return [
+            RubricGradedEnv(
+                renderer=self.renderer,
+                datapoint=self.datapoint,
+                grader_llm=self.grader_llm,
+            )
+            for _ in range(self.group_size)
+        ]
+
+
+@dataclass(frozen=True)
+class RubricGradedDataset(RLDataset):
+    renderer: Renderer
+    batch_size: int
+    group_size: int
+    datapoints: Sequence[RubricBasedDatapoint]
+    grader_llm: MessageCompleter
+
+    def get_batch(self, index: int) -> Sequence[RubricGradedEnvGroupBuilder]:
+        batch = [
+            RubricGradedEnvGroupBuilder(
+                renderer=self.renderer,
+                datapoint=self.datapoints[index * self.batch_size + i],
+                grader_llm=self.grader_llm,
+                group_size=self.group_size,
+            )
+            for i in range(self.batch_size)
+        ]
+        return batch
+
+    def __len__(self) -> int:
+        return len(self.datapoints) // self.batch_size
+
+
+@chz.chz
+class RubricGradedDatasetBuilder(RLDatasetBuilder):
+    renderer_name: str
+    model_name_for_tokenizer: str
+    batch_size: int
+    train_group_size: int
+    test_group_size: int = 1
+
+    train_datapoint_list_builder: RubricDatapointListBuilder
+    test_datapoint_list_builder: RubricDatapointListBuilder | None = None
+
+    base_url: str | None = None
+    grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+
+    def _get_grader_llm(self) -> MessageCompleter:
+        tokenizer = get_tokenizer(self.grader_llm_name)
+        renderer_name = model_info.get_recommended_renderer_name(self.grader_llm_name)
+        renderer = get_renderer(name=renderer_name, tokenizer=tokenizer)
+        service_client = tinker.ServiceClient(base_url=self.base_url)
+        sampling_client = service_client.create_sampling_client(base_model=self.grader_llm_name)
+        return TinkerMessageCompleter(
+            sampling_client=sampling_client, renderer=renderer, max_tokens=2048
+        )
+
+    async def __call__(self) -> tuple[RubricGradedDataset, RubricGradedDataset | None]:
+        train_datapoints = self.train_datapoint_list_builder()
+        test_datapoints = None
+        if self.test_datapoint_list_builder is not None:
+            test_datapoints = self.test_datapoint_list_builder()
+
+        renderer = get_renderer(
+            name=self.renderer_name, tokenizer=get_tokenizer(self.model_name_for_tokenizer)
+        )
+
+        assert train_datapoints is not None, "Train datapoints are required"
+        train_dataset = RubricGradedDataset(
+            renderer=renderer,
+            batch_size=self.batch_size,
+            group_size=self.train_group_size,
+            datapoints=train_datapoints,
+            grader_llm=self._get_grader_llm(),
+        )
+        if test_datapoints is None:
+            return train_dataset, None
+        else:
+            test_dataset = RubricGradedDataset(
+                renderer=renderer,
+                batch_size=len(test_datapoints),
+                group_size=self.test_group_size,
+                datapoints=test_datapoints,
+                grader_llm=self._get_grader_llm(),
+            )
+            return train_dataset, test_dataset
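+
+
+# RubricGradedDatasetBuilder is intended to be passed as `dataset_builder` to
+# tinker_cookbook.rl.train.Config; see train.py and prometheus_experimental.py in this
+# directory for end-to-end usage.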
"w") as f: + for datapoint in train_datapoints: + f.write(datapoint.to_json() + "\n") + print(f"Generated {len(train_datapoints)} train datapoints in {train_jsonl_path}") + + test_datapoints = datapoints[num_train:] + test_jsonl_path = os.path.join(write_dir, "example_rubric_test.jsonl") + with open(test_jsonl_path, "w") as f: + for datapoint in test_datapoints: + f.write(datapoint.to_json() + "\n") + print(f"Generated {len(test_datapoints)} test datapoints in {test_jsonl_path}") + + return train_jsonl_path, test_jsonl_path + + +if __name__ == "__main__": + train_jsonl_path, test_jsonl_path = generate_dataset(num_train=10000, num_test=1000, seed=42) + print(f"Generated train dataset in {train_jsonl_path}") + print(f"Generated test dataset in {test_jsonl_path}") diff --git a/tinker_cookbook/recipes/rubric/prometheus_experimental.py b/tinker_cookbook/recipes/rubric/prometheus_experimental.py new file mode 100644 index 0000000..2d7eb79 --- /dev/null +++ b/tinker_cookbook/recipes/rubric/prometheus_experimental.py @@ -0,0 +1,140 @@ +import chz +import asyncio +from datetime import datetime +from tinker_cookbook import cli_utils, model_info +from tinker_cookbook.rl.train import AsyncConfig, Config, main +from tinker_cookbook.rl.types import RLDatasetBuilder +from tinker.types import LossFnType +from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder +from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder + + +@chz.chz +class CLIConfig: + """Simple command-line configuration for RL training.""" + + # Model configuration + model_name: str = "meta-llama/Llama-3.1-8B-Instruct" + lora_rank: int = 32 + renderer_name: str | None = None + load_checkpoint_path: str | None = None + + seed: int = 0 # Random seed for data shuffling + + # Training hyperparameters + train_group_size: int = 4 + test_group_size: int = 1 + groups_per_batch: int = 100 + learning_rate: float = 1e-5 + max_tokens: int = 5 + temperature: float = 1.0 + kl_penalty_coef: float = 0.0 + grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507" + # Number of optimizer steps per training iteration. + # Useful for very large batch sizes. 
+def get_dataset_builder(
+    batch_size: int,
+    policy_model_name: str,
+    renderer_name: str,
+    grader_llm_name: str,
+    train_group_size: int,
+    test_group_size: int = 1,
+) -> RLDatasetBuilder:
+    return RubricGradedDatasetBuilder(
+        batch_size=batch_size,
+        model_name_for_tokenizer=policy_model_name,
+        renderer_name=renderer_name,
+        grader_llm_name=grader_llm_name,
+        train_datapoint_list_builder=PrometheusDatapointListBuilder(),
+        test_datapoint_list_builder=None,
+        train_group_size=train_group_size,
+        test_group_size=test_group_size,
+    )
+
+
+async def cli_main(cli_config: CLIConfig):
+    """Convert CLI config to full config and run training."""
+
+    # Get tokenizer for stop sequences
+    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(
+        cli_config.model_name
+    )
+    model_name = cli_config.model_name.replace("/", "-")
+    run_name = f"prometheus_experimental-{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
+    # create log path if it doesn't exist
+    if cli_config.log_path is not None:
+        log_path = cli_config.log_path
+    else:
+        log_path = f"/tmp/tinker-examples/rubric/{run_name}"
+
+    if cli_config.wandb_name is not None:
+        wandb_name = cli_config.wandb_name
+    else:
+        wandb_name = run_name
+
+    # Create full config
+    config = Config(
+        learning_rate=cli_config.learning_rate,
+        dataset_builder=get_dataset_builder(
+            batch_size=cli_config.groups_per_batch,
+            policy_model_name=cli_config.model_name,
+            renderer_name=renderer_name,
+            grader_llm_name=cli_config.grader_llm_name,
+            train_group_size=cli_config.train_group_size,
+            test_group_size=cli_config.test_group_size,
+        ),
+        model_name=cli_config.model_name,
+        lora_rank=cli_config.lora_rank,
+        max_tokens=cli_config.max_tokens,
+        temperature=cli_config.temperature,
+        wandb_project=cli_config.wandb_project,
+        wandb_name=wandb_name,
+        log_path=log_path,
+        base_url=cli_config.base_url,
+        load_checkpoint_path=cli_config.load_checkpoint_path,
+        compute_post_kl=cli_config.compute_post_kl,
+        kl_penalty_coef=cli_config.kl_penalty_coef,
+        num_substeps=cli_config.num_substeps,
+        eval_every=cli_config.eval_every,
+        save_every=cli_config.save_every,
+        async_config=AsyncConfig(
+            max_steps_off_policy=cli_config.max_steps_off_policy,
+            groups_per_batch=cli_config.groups_per_batch,
+        )
+        if cli_config.max_steps_off_policy is not None
+        else None,
+        loss_fn=cli_config.loss_fn,
+    )
+
+    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
+
+    # Run training
+    await main(config)
+
+
+if __name__ == "__main__":
+    cli_config = chz.entrypoint(CLIConfig)
+    asyncio.run(cli_main(cli_config))
diff --git a/tinker_cookbook/recipes/rubric/train.py b/tinker_cookbook/recipes/rubric/train.py
new file mode 100644
index 0000000..88953d4
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/train.py
@@ -0,0 +1,151 @@
+import chz
+import asyncio
+from datetime import datetime
+from tinker_cookbook import cli_utils, model_info
+from tinker_cookbook.rl.train import AsyncConfig, Config, main
+from tinker_cookbook.rl.types import RLDatasetBuilder
+from tinker.types import LossFnType
+from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder
+from tinker_cookbook.recipes.rubric.data import RubricDatapointListBuilderFromJsonl
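+
+# Trains on any dataset saved in the JSONL format defined in data.py; the defaults below
+# point at the addition data produced by generate_data.py.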
+
+
+@chz.chz
+class CLIConfig:
+    """Simple command-line configuration for RL training."""
+
+    # Model configuration
+    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+    lora_rank: int = 32
+    renderer_name: str | None = None
+    load_checkpoint_path: str | None = None
+
+    seed: int = 0  # Random seed for data shuffling
+
+    # Training hyperparameters
+    train_group_size: int = 4
+    test_group_size: int = 1
+    groups_per_batch: int = 100
+    learning_rate: float = 1e-5
+    max_tokens: int = 5
+    temperature: float = 1.0
+    kl_penalty_coef: float = 0.0
+    grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+    train_jsonl_path: str = "tinker_cookbook/example_data/example_rubric_train.jsonl"
+    test_jsonl_path: str = "tinker_cookbook/example_data/example_rubric_test.jsonl"
+
+    # Number of optimizer steps per training iteration.
+    # Useful for very large batch sizes.
+    num_substeps: int = 1
+
+    # Logging configuration
+    log_path: str | None = None
+    wandb_project: str | None = None
+    wandb_name: str | None = None
+    compute_post_kl: bool = False
+
+    # Evals
+    eval_every: int = 20
+
+    # Checkpointing
+    save_every: int = 20
+
+    # Service configuration
+    base_url: str | None = None
+
+    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"
+
+    max_steps_off_policy: int | None = None
+    loss_fn: LossFnType = "importance_sampling"
+
+
+def get_dataset_builder(
+    batch_size: int,
+    policy_model_name: str,
+    renderer_name: str,
+    grader_llm_name: str,
+    train_group_size: int,
+    train_jsonl_path: str,
+    test_jsonl_path: str | None = None,
+    test_group_size: int = 1,
+) -> RLDatasetBuilder:
+    return RubricGradedDatasetBuilder(
+        batch_size=batch_size,
+        model_name_for_tokenizer=policy_model_name,
+        renderer_name=renderer_name,
+        grader_llm_name=grader_llm_name,
+        train_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(
+            jsonl_path=train_jsonl_path
+        ),
+        test_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=test_jsonl_path)
+        if test_jsonl_path is not None
+        else None,
+        train_group_size=train_group_size,
+        test_group_size=test_group_size,
+    )
+
+
+async def cli_main(cli_config: CLIConfig):
+    """Convert CLI config to full config and run training."""
+
+    # Get tokenizer for stop sequences
+    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(
+        cli_config.model_name
+    )
+    model_name = cli_config.model_name.replace("/", "-")
+    run_name = f"{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
+    # create log path if it doesn't exist
+    if cli_config.log_path is not None:
+        log_path = cli_config.log_path
+    else:
+        log_path = f"/tmp/tinker-examples/rubric/{run_name}"
+
+    if cli_config.wandb_name is not None:
+        wandb_name = cli_config.wandb_name
+    else:
+        wandb_name = run_name
+
+    # Create full config
+    config = Config(
+        learning_rate=cli_config.learning_rate,
+        dataset_builder=get_dataset_builder(
+            batch_size=cli_config.groups_per_batch,
+            policy_model_name=cli_config.model_name,
+            renderer_name=renderer_name,
+            grader_llm_name=cli_config.grader_llm_name,
+            train_group_size=cli_config.train_group_size,
+            train_jsonl_path=cli_config.train_jsonl_path,
+            test_jsonl_path=cli_config.test_jsonl_path,
+            test_group_size=cli_config.test_group_size,
+        ),
+        model_name=cli_config.model_name,
+        lora_rank=cli_config.lora_rank,
+        max_tokens=cli_config.max_tokens,
+        temperature=cli_config.temperature,
+        wandb_project=cli_config.wandb_project,
+        wandb_name=wandb_name,
+        log_path=log_path,
+        base_url=cli_config.base_url,
+        load_checkpoint_path=cli_config.load_checkpoint_path,
+        compute_post_kl=cli_config.compute_post_kl,
+        kl_penalty_coef=cli_config.kl_penalty_coef,
+        num_substeps=cli_config.num_substeps,
+        eval_every=cli_config.eval_every,
+        save_every=cli_config.save_every,
+        async_config=AsyncConfig(
+            max_steps_off_policy=cli_config.max_steps_off_policy,
+            groups_per_batch=cli_config.groups_per_batch,
+        )
+        if cli_config.max_steps_off_policy is not None
+        else None,
+        loss_fn=cli_config.loss_fn,
+    )
+
+    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
+
+    # Run training
+    await main(config)
+
+
+if __name__ == "__main__":
+    cli_config = chz.entrypoint(CLIConfig)
+    asyncio.run(cli_main(cli_config))