diff --git a/tinker_cookbook/recipes/rubric/README.md b/tinker_cookbook/recipes/rubric/README.md
new file mode 100644
index 0000000..92a3026
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/README.md
@@ -0,0 +1,94 @@
+# Rubric-based Grading for LLMs
+
+- [`data.py`](./data.py) contains the definition for the datapoint class. Each datapoint consists of a conversation prefix and a list of rubric items.
+- [`generate_data.py`](./generate_data.py) generates some example datapoints if you want to run our demo on addition.
+- [`env.py`](./env.py) determines what each rollout does: the policy reads the conversation prefix and generates a response, a grader LLM scores the response against each rubric item, and the reward is the average of the rubric scores (see the sketch after this list).
+- [`train.py`](./train.py) allows you to train LLMs on any dataset saved in our format (specified in `data.py`). The default script will train on the addition task, whose data is generated by `generate_data.py`.
+- [`prometheus_experimental.py`](./prometheus_experimental.py) contains a script to train LLMs on the rubrics from the [`prometheus-eval/Feedback-Collection`](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/viewer/default/train?row=0&views%5B%5D=train) dataset. This recipe is experimental: the reward goes up, but there is no guarantee that the model actually gets better. We hope our script serves as a starting point; more research is needed.
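+
+For reference, the reward computation in [`env.py`](./env.py) boils down to a few lines. Here is a minimal synchronous sketch; the `grade` callable is a hypothetical placeholder for the grader LLM call, which in the real code is an async `MessageCompleter`:
+
+```
+from tinker_cookbook.recipes.rubric.data import Rubric
+
+
+def compute_reward(convo, rubric_items, grade) -> float:
+    # `grade` is a hypothetical stand-in for the grader LLM call (list of messages -> response string);
+    # in env.py this is an async MessageCompleter and all rubric items are graded concurrently.
+    scores = []
+    for rubric in rubric_items:
+        grader_prompt = rubric.get_grader_prompt(convo)        # conversation + rubric -> grader messages
+        grader_response = grade(grader_prompt)                 # ask the grader LLM
+        scores.append(rubric.extract_score(grader_response))   # regex + float parse; 0.0 on failure
+    return sum(scores) / len(scores)                           # reward = average rubric score
+```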
+
+
+## A simple example of using a grader LLM with rubrics
+
+We show how to use a grader LLM with a rubric to provide a reward for an addition task. For example:
+
+```
+**User**: What's 233 + 100?
+**Assistant**: 333
+```
+
+Normally this response could be graded by matching the number against the ground truth 333, without any LLM. For pedagogical purposes, however, we grade it with a language model and a rubric: we ask a grader model "Does the chatbot correctly get the answer 333?"
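+
+Concretely, using the `Rubric` class from [`data.py`](./data.py), the grader prompt for this question can be constructed as follows (a minimal sketch; the exact prompt wording lives in `Rubric.get_grader_prompt`):
+
+```
+from tinker_cookbook.recipes.rubric.data import Rubric
+
+rubric = Rubric(rubric_str="Does the chatbot correctly get the answer 333?")
+convo = [
+    {"role": "user", "content": "What's 233 + 100?"},
+    {"role": "assistant", "content": "333"},
+]
+# A single user-turn message that asks the grader to score the conversation against the rubric
+grader_prompt = rubric.get_grader_prompt(convo)
+print(grader_prompt[0]["content"])
+# The grader's reply is then converted to a float reward via rubric.extract_score(...)
+```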
+
+### Generate an example dataset
+
+To run this, first generate a dataset:
+
+```
+python -m tinker_cookbook.recipes.rubric.generate_data
+```
+
+This generates two `jsonl` files, one for training and one for testing. If you look into `tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of:
+- `convo`: the conversation prefix that the policy sees
+- `rubric_items`: a list of rubric items specifying what counts as a good response, how the grader should format its output, and how the score should be extracted from the grader's response
+
+```
+{
+ "convo": [
+ {
+ "role": "user",
+ "content": "What is 4 + 5?"
+ },
+ {
+ "role": "assistant",
+ "content": "9"
+ },
+ {
+ "role": "user",
+ "content": "What is 122 + 12?"
+ }
+ ],
+ "rubric_items": [
+ {
+ "rubric_str": "Does the chatbot correctly get the answer 134?",
+ "extraction_regex": "(.*)",
+ "grader_output_format_instruction": "Please output your score between 0 and 1 wrapped in ... "
+ }
+ ]
+}
+```
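+
+These are the files that `train.py` consumes via `RubricDatapointListBuilderFromJsonl`; you can also load them yourself to inspect them:
+
+```
+from tinker_cookbook.recipes.rubric.data import RubricDatapointListBuilderFromJsonl
+
+builder = RubricDatapointListBuilderFromJsonl(
+    jsonl_path="tinker_cookbook/example_data/example_rubric_train.jsonl"
+)
+datapoints = builder()  # Sequence[RubricBasedDatapoint]
+print(len(datapoints))  # 10000 with the default generate_data settings
+print(datapoints[0].rubric_items[0].rubric_str)
+```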
+
+### Debugging and Printing What Happens During Rollouts
+
+Run
+```
+python -m tinker_cookbook.recipes.rubric.debug_env
+```
+
+This prints the messages the policy sees, the policy's response, the grader's input, and the grader's output. The script uses the addition datapoint by default; to debug a Prometheus datapoint instead, change `dataset = "addition"` to `dataset = "prometheus"` at the bottom of `debug_env.py`.
+
+
+
+
+### An example training run
+
+To train the LLM to add with a rubric-based LLM, run
+```
+python -m tinker_cookbook.recipes.rubric.train
+```
+
+You should see the reward go up quickly.
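+
+The defaults in `train.py`'s `CLIConfig` (policy model, LoRA rank, learning rate, group size, grader model, etc.) can be overridden on the command line. Assuming chz's `key=value` override syntax, for example:
+
+```
+python -m tinker_cookbook.recipes.rubric.train \
+    model_name=meta-llama/Llama-3.1-8B-Instruct \
+    grader_llm_name=Qwen/Qwen3-30B-A3B-Instruct-2507 \
+    learning_rate=1e-5 \
+    groups_per_batch=100
+```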
+
+
+
+### A more realistic dataset
+
+We take the `prometheus-eval/Feedback-Collection` dataset from [Hugging Face](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubrics to grade general chat responses. Run the following to kick off training:
+
+```
+python -m tinker_cookbook.recipes.rubric.prometheus_experimental
+```
+
+You should see the reward climb steadily.
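+
+To see how a Feedback-Collection row gets converted into our rubric format (the rubric text combines the grading criteria, the per-score calibrations, and the reference answer), you can build the datapoints directly; note that this downloads the dataset from Hugging Face:
+
+```
+from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder
+
+datapoints = PrometheusDatapointListBuilder()()
+print(datapoints[0].convo[0]["content"])         # the original instruction shown to the policy
+print(datapoints[0].rubric_items[0].rubric_str)  # criteria + score calibrations + reference answer
+```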
+
+
+
+Note that this training recipe is experimental -- to improve performance further, we may need to fine-tune the grader LLM as well. We hope our code serves as a starting point for you to improve rubric-based grading for training LLMs!
diff --git a/tinker_cookbook/recipes/rubric/data.py b/tinker_cookbook/recipes/rubric/data.py
new file mode 100644
index 0000000..c9011b6
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/data.py
@@ -0,0 +1,169 @@
+from tinker_cookbook.renderers import (
+ Message,
+ Role,
+)
+from typing import TypeAlias
+from dataclasses import dataclass
+from typing import Sequence
+import re
+import json
+import chz
+
+Conversation: TypeAlias = list[Message]
+
+
+@dataclass
+class Rubric:
+ """
+ A rubric should specify 1) what counts as a good response, 2) how the grader language model should output the score, and 3) how to extract the score from the grader's response.
+ """
+
+ rubric_str: str
+ extraction_regex: str = r"(.*)"
+ grader_output_format_instruction: str = (
+ "Please output your score between 0 and 1 wrapped in ... "
+ )
+
+ def __convert_role(self, role: Role) -> str:
+ return "Human" if role in ("user", "system") else "Chatbot"
+
+ def _flatten_convo(self, convo: Conversation) -> str:
+ """
+ Convert the whole conversation (user's turns + assistant's turns) into a single string. E.g.
+ \n\nHuman: ....
+ \n\nChatbot: ...
+ \n\nHuman: ...
+ \n\nChatbot: ...
+ """
+ return "\n\n".join(
+ [f"{self.__convert_role(message['role'])}: {message['content']}" for message in convo]
+ )
+
+ def get_grader_prompt(self, convo: Conversation) -> Conversation:
+ """
+ Create a prompt for the grader to grade the conversation based on the rubric. The prompt should contain 1) the conversation to be graded, and 2) the rubric.
+ """
+
+ prompt = "I will show you 1) a conversation between a human and a chatbot, and 2) a rubric for grading the conversation. Please grade the conversation based on the rubric."
+
+ prompt += f"Here is the conversation: \n\n{self._flatten_convo(convo)} \n\n\n\nHere is the rubric: \n{self.rubric_str}\n\n"
+ prompt += f"Please grade the conversation based on the rubric. {self.grader_output_format_instruction}"
+ return [
+ {
+ "role": "user",
+ "content": prompt,
+ }
+ ]
+
+ def extract_score(self, response: str) -> float:
+ match = re.search(self.extraction_regex, response, re.DOTALL)
+ if match is not None:
+ try:
+ return float(match.group(1))
+ except ValueError:
+ print(f"Warning: Failed to extract score from grader response: {response}")
+ return 0.0
+ else:
+ print(f"Warning: Failed to extract score from grader response: {response}")
+ return 0.0
+
+ def to_dict(self) -> dict[str, str]:
+ return {
+ "rubric_str": self.rubric_str,
+ "extraction_regex": self.extraction_regex,
+ "grader_output_format_instruction": self.grader_output_format_instruction,
+ }
+
+ def to_json(self) -> str:
+ return json.dumps(self.to_dict())
+
+ @staticmethod
+ def from_dict(d: dict[str, str]) -> "Rubric":
+ return Rubric(
+ rubric_str=d["rubric_str"],
+ extraction_regex=d["extraction_regex"],
+ grader_output_format_instruction=d["grader_output_format_instruction"],
+ )
+
+ @staticmethod
+ def from_json(json_str: str) -> "Rubric":
+ return Rubric.from_dict(json.loads(json_str))
+
+
+@dataclass(frozen=True)
+class RubricBasedDatapoint:
+ """
+    A rubric-based datapoint contains a conversation prefix and a list of rubric items.
+    In this task, the policy model sees the conversation, creates a response, and then the grader language model grades the response based on the rubric items.
+ """
+
+ convo: Conversation
+ rubric_items: Sequence[Rubric]
+
+ def to_json(self) -> str:
+ return json.dumps(
+ {
+ "convo": self.convo,
+ "rubric_items": [rubric.to_dict() for rubric in self.rubric_items],
+ }
+ )
+
+ @staticmethod
+ def from_json(json_str: str) -> "RubricBasedDatapoint":
+ d = json.loads(json_str)
+ return RubricBasedDatapoint(
+ convo=d["convo"],
+ rubric_items=[Rubric.from_dict(rubric) for rubric in d["rubric_items"]],
+ )
+
+
+@chz.chz
+class RubricDatapointListBuilder:
+ def __call__(self) -> Sequence[RubricBasedDatapoint]:
+ raise NotImplementedError("Subclass must implement this method")
+
+
+@chz.chz
+class RubricDatapointListBuilderFromJsonl(RubricDatapointListBuilder):
+ jsonl_path: str
+
+ def __call__(self) -> Sequence[RubricBasedDatapoint]:
+ datapoints = []
+ with open(self.jsonl_path, "r") as f:
+ for line in f:
+ datapoints.append(RubricBasedDatapoint.from_json(line))
+ return datapoints
+
+
+@chz.chz
+class PrometheusDatapointListBuilder(RubricDatapointListBuilder):
+ data_path: str = "prometheus-eval/Feedback-Collection"
+
+ def __call__(self) -> Sequence[RubricBasedDatapoint]:
+ from datasets import load_dataset
+
+ train_dataset = load_dataset(self.data_path)["train"]
+ return [self.build_rubric_datapoint(item) for item in train_dataset] # type: ignore
+
+ def build_rubric_datapoint(self, item: dict) -> RubricBasedDatapoint:
+ convo: Conversation = [
+ {"role": "user", "content": item["orig_instruction"]},
+ ]
+
+ rubric_text = f"Your job is to evaluate the following: {item['orig_criteria']}. Your response should be a score between 1 to 5.\n"
+ rubric_text += "Here is the calibration for each score:\n"
+ for i in range(1, 6):
+ rubric_text += f"{i}.0: {item[f'orig_score{i}_description']}\n"
+
+ rubric_text += f"\nHere is a reference response that achieved a score of 5: {item['orig_reference_answer']}\n"
+
+ rubric = Rubric(
+ rubric_str=rubric_text,
+ extraction_regex=r"(.*)",
+ grader_output_format_instruction="Please output your score between 1 and 5 wrapped in ... ",
+ )
+
+ return RubricBasedDatapoint(
+ convo=convo,
+ rubric_items=[rubric],
+ )
diff --git a/tinker_cookbook/recipes/rubric/debug_env.py b/tinker_cookbook/recipes/rubric/debug_env.py
new file mode 100644
index 0000000..84950bb
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/debug_env.py
@@ -0,0 +1,74 @@
+from tinker_cookbook import model_info
+from tinker_cookbook.recipes.rubric.env import RubricGradedEnv, RubricBasedDatapoint, Rubric
+from tinker_cookbook.completers import TinkerMessageCompleter, TinkerTokenCompleter
+from tinker_cookbook.renderers import get_renderer
+from tinker_cookbook.tokenizer_utils import get_tokenizer
+import tinker
+from tinker_cookbook.rl.rollouts import do_single_rollout
+import asyncio
+
+
+def get_addition_datapoint() -> RubricBasedDatapoint:
+ datapoint = RubricBasedDatapoint(
+ convo=[
+ {"role": "user", "content": "What is 4 + 5?"},
+ {"role": "assistant", "content": "9"},
+ {"role": "user", "content": "What is 125 + 311?"},
+ ],
+ rubric_items=[
+ Rubric(rubric_str="Does the chatbot correctly get the answer 436?"),
+ Rubric(rubric_str="Does the chatbot provide an answer without saying anything else?"),
+ ],
+ )
+
+ return datapoint
+
+
+def get_prometheus_datapoint() -> RubricBasedDatapoint:
+ from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder
+
+    datapoints = PrometheusDatapointListBuilder()()
+    return datapoints[0]
+
+
+async def main(datapoint: RubricBasedDatapoint):
+ policy_name = "meta-llama/Llama-3.1-8B-Instruct"
+ grader_name = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+ service_client = tinker.ServiceClient()
+ policy = TinkerTokenCompleter(
+ sampling_client=service_client.create_sampling_client(base_model=policy_name),
+ max_tokens=64,
+ )
+ policy_renderer = get_renderer(
+ model_info.get_recommended_renderer_name(policy_name), get_tokenizer(policy_name)
+ )
+ grader = TinkerMessageCompleter(
+ sampling_client=service_client.create_sampling_client(base_model=grader_name),
+ renderer=get_renderer(
+ model_info.get_recommended_renderer_name(grader_name), get_tokenizer(grader_name)
+ ),
+ max_tokens=64,
+ )
+
+ env = RubricGradedEnv(
+ renderer=policy_renderer,
+ datapoint=datapoint,
+ grader_llm=grader,
+ debug=True,
+ )
+
+ await do_single_rollout(policy, env)
+
+
+if __name__ == "__main__":
+ dataset = "addition"
+
+ if dataset == "addition":
+ datapoint = get_addition_datapoint()
+ asyncio.run(main(datapoint))
+ elif dataset == "prometheus":
+ datapoint = get_prometheus_datapoint()
+ asyncio.run(main(datapoint))
+ else:
+ raise ValueError(f"Unknown dataset: {dataset}")
diff --git a/tinker_cookbook/recipes/rubric/env.py b/tinker_cookbook/recipes/rubric/env.py
new file mode 100644
index 0000000..218f052
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/env.py
@@ -0,0 +1,214 @@
+from tinker_cookbook.rl.types import (
+ Action,
+ Env,
+ StepResult,
+ EnvGroupBuilder,
+ RLDataset,
+ RLDatasetBuilder,
+)
+from tinker_cookbook.renderers import Renderer
+from tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter
+from tinker.types import ModelInput
+from dataclasses import dataclass
+from typing import Sequence
+import json
+import chz
+import tinker
+from tinker_cookbook.tokenizer_utils import get_tokenizer
+from tinker_cookbook.renderers import get_renderer
+import asyncio
+from tinker_cookbook import model_info
+from tinker_cookbook.recipes.rubric.data import (
+ RubricBasedDatapoint,
+ Rubric,
+ Conversation,
+ RubricDatapointListBuilder,
+)
+
+# ANSI color codes
+BLUE = "\033[94m"
+GREEN = "\033[92m"
+YELLOW = "\033[93m"
+MAGENTA = "\033[95m"
+RESET = "\033[0m"
+
+
+class RubricGradedEnv(Env):
+ def __init__(
+ self,
+ renderer: Renderer,
+ datapoint: RubricBasedDatapoint,
+ grader_llm: MessageCompleter,
+ debug: bool = False,
+ ):
+ """
+        Initialize the RubricGradedEnv. In this environment, the policy model sees the conversation, creates a response, and then the grader language model grades the response based on the rubric items.
+ """
+ self.renderer = renderer
+ self.datapoint = datapoint
+ self.grader_llm = grader_llm
+ self.debug = debug
+
+ @property
+ def rubric_items(self) -> Sequence[Rubric]:
+ return self.datapoint.rubric_items
+
+ @property
+ def convo(self) -> Conversation:
+ return self.datapoint.convo
+
+ @property
+ def stop_condition(self) -> StopCondition:
+ return self.renderer.get_stop_sequences()
+
+ async def initial_observation(self) -> tuple[ModelInput, StopCondition]:
+ return self.renderer.build_generation_prompt(self.convo), self.stop_condition
+
+ async def _grade_with_rubric(self, convo: Conversation, rubric: Rubric) -> float:
+ # this is the conversation for the grader
+ # effectively it's just one user turn
+ grader_prompt = rubric.get_grader_prompt(convo)
+
+ # obtain the response from the grader and convert it to a score
+ grader_response = await self.grader_llm(grader_prompt)
+ grader_response_content = grader_response["content"]
+ assert isinstance(grader_response_content, str), "Grader response content must be a string"
+ score = rubric.extract_score(grader_response_content)
+ if self.debug:
+ print(f"{YELLOW}{'=' * 80}")
+ print("DEBUG: First Turn of Grader Prompt")
+ print(f"{'=' * 80}{RESET}")
+ print(f"{YELLOW}{grader_prompt[0]['content']}{RESET}\n")
+
+ print(f"{MAGENTA}{'=' * 80}")
+ print("DEBUG: Score")
+ print(f"{'=' * 80}{RESET}")
+ print(f"{MAGENTA}Grader Response: {grader_response_content}{RESET}\n")
+ print(f"{MAGENTA}Extracted Score: {score}{RESET}\n")
+ return score
+
+ async def step(self, action: Action) -> StepResult:
+ # obtain the policy action message
+ (policy_action_message, _parse_success) = self.renderer.parse_response(action)
+
+ if self.debug:
+ print(f"\n{BLUE}{'=' * 80}")
+ print("DEBUG: Original Conversation (self.convo)")
+ print(f"{'=' * 80}{RESET}")
+ print(f"{BLUE}{json.dumps(self.convo, indent=2)}{RESET}\n")
+
+ print(f"{GREEN}{'=' * 80}")
+ print("DEBUG: Policy Action Message")
+ print(f"{'=' * 80}{RESET}")
+ print(f"{GREEN}{json.dumps(policy_action_message, indent=2)}{RESET}\n")
+        # append the policy's response so the grader sees the full back-and-forth conversation
+ convo = self.convo + [policy_action_message]
+
+ scores = await asyncio.gather(
+ *[self._grade_with_rubric(convo, rubric_item) for rubric_item in self.rubric_items]
+ )
+ avg_score = sum(scores) / len(scores)
+
+ return StepResult(
+ reward=avg_score,
+ episode_done=True,
+ next_observation=self.renderer.build_generation_prompt(convo),
+ next_stop_condition=self.stop_condition,
+ )
+
+
+@dataclass(frozen=True)
+class RubricGradedEnvGroupBuilder(EnvGroupBuilder):
+ renderer: Renderer
+ datapoint: RubricBasedDatapoint
+ grader_llm: MessageCompleter
+ group_size: int
+
+ async def make_envs(self) -> Sequence[RubricGradedEnv]:
+ return [
+ RubricGradedEnv(
+ renderer=self.renderer,
+ datapoint=self.datapoint,
+ grader_llm=self.grader_llm,
+ )
+ for _ in range(self.group_size)
+ ]
+
+
+@dataclass(frozen=True)
+class RubricGradedDataset(RLDataset):
+ renderer: Renderer
+ batch_size: int
+ group_size: int
+ datapoints: Sequence[RubricBasedDatapoint]
+ grader_llm: MessageCompleter
+
+ def get_batch(self, index: int) -> Sequence[RubricGradedEnvGroupBuilder]:
+ batch = [
+ RubricGradedEnvGroupBuilder(
+ renderer=self.renderer,
+ datapoint=self.datapoints[index * self.batch_size + i],
+ grader_llm=self.grader_llm,
+ group_size=self.group_size,
+ )
+ for i in range(self.batch_size)
+ ]
+ return batch
+
+ def __len__(self) -> int:
+ return len(self.datapoints) // self.batch_size
+
+
+@chz.chz
+class RubricGradedDatasetBuilder(RLDatasetBuilder):
+ renderer_name: str
+ model_name_for_tokenizer: str
+ batch_size: int
+ train_group_size: int
+ test_group_size: int = 1
+
+ train_datapoint_list_builder: RubricDatapointListBuilder
+ test_datapoint_list_builder: RubricDatapointListBuilder | None = None
+
+ base_url: str | None = None
+ grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+
+ def _get_grader_llm(self) -> MessageCompleter:
+ tokenizer = get_tokenizer(self.grader_llm_name)
+ renderer_name = model_info.get_recommended_renderer_name(self.grader_llm_name)
+ renderer = get_renderer(name=renderer_name, tokenizer=tokenizer)
+ service_client = tinker.ServiceClient(base_url=self.base_url)
+ sampling_client = service_client.create_sampling_client(base_model=self.grader_llm_name)
+ return TinkerMessageCompleter(
+ sampling_client=sampling_client, renderer=renderer, max_tokens=2048
+ )
+
+ async def __call__(self) -> tuple[RubricGradedDataset, RubricGradedDataset | None]:
+ train_datapoints = self.train_datapoint_list_builder()
+ test_datapoints = None
+ if self.test_datapoint_list_builder is not None:
+ test_datapoints = self.test_datapoint_list_builder()
+
+ renderer = get_renderer(
+ name=self.renderer_name, tokenizer=get_tokenizer(self.model_name_for_tokenizer)
+ )
+
+ assert train_datapoints is not None, "Train datapoints are required"
+ train_dataset = RubricGradedDataset(
+ renderer=renderer,
+ batch_size=self.batch_size,
+ group_size=self.train_group_size,
+ datapoints=train_datapoints,
+ grader_llm=self._get_grader_llm(),
+ )
+ if test_datapoints is None:
+ return train_dataset, None
+ else:
+ test_dataset = RubricGradedDataset(
+ renderer=renderer,
+ batch_size=len(test_datapoints),
+ group_size=self.test_group_size,
+ datapoints=test_datapoints,
+ grader_llm=self._get_grader_llm(),
+ )
+ return train_dataset, test_dataset
diff --git a/tinker_cookbook/recipes/rubric/generate_data.py b/tinker_cookbook/recipes/rubric/generate_data.py
new file mode 100644
index 0000000..922db4d
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/generate_data.py
@@ -0,0 +1,46 @@
+from tinker_cookbook.recipes.rubric.data import RubricBasedDatapoint, Rubric
+import random
+import os
+
+
+def generate_one(rng: random.Random) -> RubricBasedDatapoint:
+ x, y = rng.randint(0, 1000), rng.randint(0, 1000)
+ return RubricBasedDatapoint(
+ convo=[
+ {"role": "user", "content": "What is 4 + 5?"},
+ {"role": "assistant", "content": "9"},
+ {"role": "user", "content": f"What is {x} + {y}?"},
+ ],
+ rubric_items=[Rubric(rubric_str=f"Does the chatbot correctly get the answer {x + y}?")],
+ )
+
+
+def generate_dataset(
+ num_train: int, num_test: int, seed: int, write_dir: str = "tinker_cookbook/example_data/"
+) -> tuple[str, str]:
+    rng = random.Random(seed)
+ total_datapoints = num_train + num_test
+ datapoints = [generate_one(rng) for _ in range(total_datapoints)]
+
+ train_datapoints = datapoints[:num_train]
+ train_jsonl_path = os.path.join(write_dir, "example_rubric_train.jsonl")
+ with open(train_jsonl_path, "w") as f:
+ for datapoint in train_datapoints:
+ f.write(datapoint.to_json() + "\n")
+ print(f"Generated {len(train_datapoints)} train datapoints in {train_jsonl_path}")
+
+ test_datapoints = datapoints[num_train:]
+ test_jsonl_path = os.path.join(write_dir, "example_rubric_test.jsonl")
+ with open(test_jsonl_path, "w") as f:
+ for datapoint in test_datapoints:
+ f.write(datapoint.to_json() + "\n")
+ print(f"Generated {len(test_datapoints)} test datapoints in {test_jsonl_path}")
+
+ return train_jsonl_path, test_jsonl_path
+
+
+if __name__ == "__main__":
+ train_jsonl_path, test_jsonl_path = generate_dataset(num_train=10000, num_test=1000, seed=42)
+ print(f"Generated train dataset in {train_jsonl_path}")
+ print(f"Generated test dataset in {test_jsonl_path}")
diff --git a/tinker_cookbook/recipes/rubric/prometheus_experimental.py b/tinker_cookbook/recipes/rubric/prometheus_experimental.py
new file mode 100644
index 0000000..2d7eb79
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/prometheus_experimental.py
@@ -0,0 +1,140 @@
+import chz
+import asyncio
+from datetime import datetime
+from tinker_cookbook import cli_utils, model_info
+from tinker_cookbook.rl.train import AsyncConfig, Config, main
+from tinker_cookbook.rl.types import RLDatasetBuilder
+from tinker.types import LossFnType
+from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder
+from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder
+
+
+@chz.chz
+class CLIConfig:
+ """Simple command-line configuration for RL training."""
+
+ # Model configuration
+ model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+ lora_rank: int = 32
+ renderer_name: str | None = None
+ load_checkpoint_path: str | None = None
+
+ seed: int = 0 # Random seed for data shuffling
+
+ # Training hyperparameters
+ train_group_size: int = 4
+ test_group_size: int = 1
+ groups_per_batch: int = 100
+ learning_rate: float = 1e-5
+ max_tokens: int = 5
+ temperature: float = 1.0
+ kl_penalty_coef: float = 0.0
+ grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+ # Number of optimizer steps per training iteration.
+ # Useful for very large batch sizes.
+ num_substeps: int = 1
+
+ # Logging configuration
+ log_path: str | None = None
+ wandb_project: str | None = None
+ wandb_name: str | None = None
+ compute_post_kl: bool = False
+
+ # Evals
+ eval_every: int = 20
+
+ # Checkpointing
+ save_every: int = 20
+
+ # Service configuration
+ base_url: str | None = None
+
+ behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"
+
+ max_steps_off_policy: int | None = None
+ loss_fn: LossFnType = "importance_sampling"
+
+
+def get_dataset_builder(
+ batch_size: int,
+ policy_model_name: str,
+ renderer_name: str,
+ grader_llm_name: str,
+ train_group_size: int,
+ test_group_size: int = 1,
+) -> RLDatasetBuilder:
+ return RubricGradedDatasetBuilder(
+ batch_size=batch_size,
+ model_name_for_tokenizer=policy_model_name,
+ renderer_name=renderer_name,
+ grader_llm_name=grader_llm_name,
+ train_datapoint_list_builder=PrometheusDatapointListBuilder(),
+ test_datapoint_list_builder=None,
+ train_group_size=train_group_size,
+ test_group_size=test_group_size,
+ )
+
+
+async def cli_main(cli_config: CLIConfig):
+ """Convert CLI config to full config and run training."""
+
+    # Resolve the renderer name for the policy model
+ renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(
+ cli_config.model_name
+ )
+ model_name = cli_config.model_name.replace("/", "-")
+ run_name = f"prometheus_experimental-{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
+    # Use the provided log path, or default to a run-specific directory under /tmp
+ if cli_config.log_path is not None:
+ log_path = cli_config.log_path
+ else:
+ log_path = f"/tmp/tinker-examples/rubric/{run_name}"
+
+ if cli_config.wandb_name is not None:
+ wandb_name = cli_config.wandb_name
+ else:
+ wandb_name = run_name
+
+ # Create full config
+ config = Config(
+ learning_rate=cli_config.learning_rate,
+ dataset_builder=get_dataset_builder(
+ batch_size=cli_config.groups_per_batch,
+ policy_model_name=cli_config.model_name,
+ renderer_name=renderer_name,
+ grader_llm_name=cli_config.grader_llm_name,
+ train_group_size=cli_config.train_group_size,
+ test_group_size=cli_config.test_group_size,
+ ),
+ model_name=cli_config.model_name,
+ lora_rank=cli_config.lora_rank,
+ max_tokens=cli_config.max_tokens,
+ temperature=cli_config.temperature,
+ wandb_project=cli_config.wandb_project,
+ wandb_name=wandb_name,
+ log_path=log_path,
+ base_url=cli_config.base_url,
+ load_checkpoint_path=cli_config.load_checkpoint_path,
+ compute_post_kl=cli_config.compute_post_kl,
+ kl_penalty_coef=cli_config.kl_penalty_coef,
+ num_substeps=cli_config.num_substeps,
+ eval_every=cli_config.eval_every,
+ save_every=cli_config.save_every,
+ async_config=AsyncConfig(
+ max_steps_off_policy=cli_config.max_steps_off_policy,
+ groups_per_batch=cli_config.groups_per_batch,
+ )
+ if cli_config.max_steps_off_policy is not None
+ else None,
+ loss_fn=cli_config.loss_fn,
+ )
+
+ cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
+
+ # Run training
+ await main(config)
+
+
+if __name__ == "__main__":
+ cli_config = chz.entrypoint(CLIConfig)
+ asyncio.run(cli_main(cli_config))
diff --git a/tinker_cookbook/recipes/rubric/train.py b/tinker_cookbook/recipes/rubric/train.py
new file mode 100644
index 0000000..88953d4
--- /dev/null
+++ b/tinker_cookbook/recipes/rubric/train.py
@@ -0,0 +1,151 @@
+import chz
+import asyncio
+from datetime import datetime
+from tinker_cookbook import cli_utils, model_info
+from tinker_cookbook.rl.train import AsyncConfig, Config, main
+from tinker_cookbook.rl.types import RLDatasetBuilder
+from tinker.types import LossFnType
+from tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder
+from tinker_cookbook.recipes.rubric.data import RubricDatapointListBuilderFromJsonl
+
+
+@chz.chz
+class CLIConfig:
+ """Simple command-line configuration for RL training."""
+
+ # Model configuration
+ model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+ lora_rank: int = 32
+ renderer_name: str | None = None
+ load_checkpoint_path: str | None = None
+
+ seed: int = 0 # Random seed for data shuffling
+
+ # Training hyperparameters
+ train_group_size: int = 4
+ test_group_size: int = 1
+ groups_per_batch: int = 100
+ learning_rate: float = 1e-5
+ max_tokens: int = 5
+ temperature: float = 1.0
+ kl_penalty_coef: float = 0.0
+ grader_llm_name: str = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+ train_jsonl_path: str = "tinker_cookbook/example_data/example_rubric_train.jsonl"
+ test_jsonl_path: str = "tinker_cookbook/example_data/example_rubric_test.jsonl"
+
+ # Number of optimizer steps per training iteration.
+ # Useful for very large batch sizes.
+ num_substeps: int = 1
+
+ # Logging configuration
+ log_path: str | None = None
+ wandb_project: str | None = None
+ wandb_name: str | None = None
+ compute_post_kl: bool = False
+
+ # Evals
+ eval_every: int = 20
+
+ # Checkpointing
+ save_every: int = 20
+
+ # Service configuration
+ base_url: str | None = None
+
+ behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"
+
+ max_steps_off_policy: int | None = None
+ loss_fn: LossFnType = "importance_sampling"
+
+
+def get_dataset_builder(
+ batch_size: int,
+ policy_model_name: str,
+ renderer_name: str,
+ grader_llm_name: str,
+ train_group_size: int,
+ train_jsonl_path: str,
+ test_jsonl_path: str | None = None,
+ test_group_size: int = 1,
+) -> RLDatasetBuilder:
+ return RubricGradedDatasetBuilder(
+ batch_size=batch_size,
+ model_name_for_tokenizer=policy_model_name,
+ renderer_name=renderer_name,
+ grader_llm_name=grader_llm_name,
+ train_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(
+ jsonl_path=train_jsonl_path
+ ),
+ test_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=test_jsonl_path)
+ if test_jsonl_path is not None
+ else None,
+ train_group_size=train_group_size,
+ test_group_size=test_group_size,
+ )
+
+
+async def cli_main(cli_config: CLIConfig):
+ """Convert CLI config to full config and run training."""
+
+    # Resolve the renderer name for the policy model
+ renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(
+ cli_config.model_name
+ )
+ model_name = cli_config.model_name.replace("/", "-")
+ run_name = f"{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
+    # Use the provided log path, or default to a run-specific directory under /tmp
+ if cli_config.log_path is not None:
+ log_path = cli_config.log_path
+ else:
+ log_path = f"/tmp/tinker-examples/rubric/{run_name}"
+
+ if cli_config.wandb_name is not None:
+ wandb_name = cli_config.wandb_name
+ else:
+ wandb_name = run_name
+
+ # Create full config
+ config = Config(
+ learning_rate=cli_config.learning_rate,
+ dataset_builder=get_dataset_builder(
+ batch_size=cli_config.groups_per_batch,
+ policy_model_name=cli_config.model_name,
+ renderer_name=renderer_name,
+ grader_llm_name=cli_config.grader_llm_name,
+ train_group_size=cli_config.train_group_size,
+ train_jsonl_path=cli_config.train_jsonl_path,
+ test_jsonl_path=cli_config.test_jsonl_path,
+ test_group_size=cli_config.test_group_size,
+ ),
+ model_name=cli_config.model_name,
+ lora_rank=cli_config.lora_rank,
+ max_tokens=cli_config.max_tokens,
+ temperature=cli_config.temperature,
+ wandb_project=cli_config.wandb_project,
+ wandb_name=wandb_name,
+ log_path=log_path,
+ base_url=cli_config.base_url,
+ load_checkpoint_path=cli_config.load_checkpoint_path,
+ compute_post_kl=cli_config.compute_post_kl,
+ kl_penalty_coef=cli_config.kl_penalty_coef,
+ num_substeps=cli_config.num_substeps,
+ eval_every=cli_config.eval_every,
+ save_every=cli_config.save_every,
+ async_config=AsyncConfig(
+ max_steps_off_policy=cli_config.max_steps_off_policy,
+ groups_per_batch=cli_config.groups_per_batch,
+ )
+ if cli_config.max_steps_off_policy is not None
+ else None,
+ loss_fn=cli_config.loss_fn,
+ )
+
+ cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
+
+ # Run training
+ await main(config)
+
+
+if __name__ == "__main__":
+ cli_config = chz.entrypoint(CLIConfig)
+ asyncio.run(cli_main(cli_config))