100 changes: 99 additions & 1 deletion torchtune/dev/rl/rewards.py
@@ -7,10 +7,15 @@
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union

import torch

from torchtune.modules.transforms.tokenizers import (
HuggingFaceModelTokenizer,
ModelTokenizer,
)


@dataclass
class RewardOutput:
@@ -216,3 +221,96 @@ def __call__(
},
successes=successes,
)


def at_least_one_space_between_think_tags(
    cot: str, answer: str, potential_answer: str
) -> tuple[float, float]:
    """Did the model at least try to think?"""
    if len(cot) > 0:
        return 1.0, 1.0  # (reward, success)
    else:
        return 0.0, 0.0


def math_response_correct(
    cot: str, answer: str, potential_answer: str
) -> tuple[float, float]:
    """Did it get the right answer?"""
    # Imported lazily so math_verify is only needed when this reward is used.
    import math_verify

    if potential_answer is None:
        return 0.0, 0.0  # (reward, success)
    gold = math_verify.parse(answer)
    attempt = math_verify.parse(potential_answer)

    # Tiered scoring: full reward for a verified match, partial reward when the
    # gold answer appears verbatim in the attempt, a token reward for any
    # non-empty attempt, and nothing otherwise.
    if math_verify.verify(gold, attempt):
        return 100.0, 1.0
    if answer in potential_answer:
        return 50.0, 0.0
    if len(potential_answer) > 0:
        return 1.0, 0.0
    return 0.0, 0.0
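
# Illustration (not part of the diff), assuming the math_verify package is installed:
#   math_response_correct("...", "42", "42")   -> (100.0, 1.0)  verified match
#   math_response_correct("...", "42", "-42")  -> (50.0, 0.0)   "42" is a substring, but -42 != 42
#   math_response_correct("...", "42", "7")    -> (1.0, 0.0)    non-empty but wrong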


def extract_tags(text: str) -> tuple[str, str]:
    """
    Parse XML-like <think> and <answer> tags from text. Returns a tuple
    (cot, potential_answer), where each element is the stripped content of the
    first matching tag, or an empty string if the tag is not found.
    """
    think_pattern = r"<think>(.*?)</think>"
    answer_pattern = r"<answer>(.*?)</answer>"
    think_match = re.search(think_pattern, text, re.DOTALL)
    answer_match = re.search(answer_pattern, text, re.DOTALL)
    cot = think_match.group(1).strip() if think_match else ""
    potential_answer = answer_match.group(1).strip() if answer_match else ""
    return cot, potential_answer
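
# Illustration (not part of the diff): both tags must be present, which is why
# callers prepend the opening <think> tag before parsing:
#   extract_tags("<think>Add the numbers.</think> <answer>4</answer>")
#       -> ("Add the numbers.", "4")
#   extract_tags("<think>no closing tag here")
#       -> ("", "")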


def batched_rewards(
    tokenizer: Union[ModelTokenizer, HuggingFaceModelTokenizer],
    completions: torch.Tensor,
    answers: list[list[str]],
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """Score a [batch_size, grpo_size, seq_len] tensor of completions against
    every reward function, returning reward and success tensors of shape
    [batch_size, grpo_size, num_reward_funcs] plus a metadata dict with the
    reward function names."""

    reward_funcs = [
        at_least_one_space_between_think_tags,
        math_response_correct,
    ]
    num_reward_funcs = len(reward_funcs)
    batch_size, grpo_size, _ = completions.shape

    # TODO: should this be bfloat16?
    rewards_tensor = torch.zeros(
        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
    )
    successes_tensor = torch.zeros(
        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
    )
    metadata = {"func_names": [f.__name__ for f in reward_funcs]}

    for b in range(batch_size):
        for g in range(grpo_size):
            answer = answers[b][g]
            text_completion = tokenizer.decode(completions[b, g].tolist())
            # Completions are assumed to start mid-<think> (the opening tag
            # comes from the prompt), so prepend it before parsing.
            cot, potential_answer = extract_tags(f"<think>{text_completion}")
            for rw_idx, reward_func in enumerate(reward_funcs):
                reward, success = reward_func(cot, answer, potential_answer)
                rewards_tensor[b, g, rw_idx] += reward
                successes_tensor[b, g, rw_idx] += success

    return rewards_tensor, successes_tensor, metadata
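
A minimal usage sketch of the new helper (not part of the diff). The stub tokenizer below is a hypothetical stand-in: batched_rewards only calls decode(list[int]) -> str, while real call sites pass a torchtune ModelTokenizer or HuggingFaceModelTokenizer, and math_verify must be installed for the correctness reward.

import torch

from torchtune.dev.rl.rewards import batched_rewards


class StubTokenizer:
    """Hypothetical stand-in; only decode() is exercised by batched_rewards."""

    def decode(self, token_ids: list[int]) -> str:
        # Pretend every rollout produced the same text (the prompt supplied "<think>").
        return "2 + 2 = 4</think> <answer>4</answer>"


batch_size, grpo_size, seq_len = 2, 4, 8
completions = torch.zeros(batch_size, grpo_size, seq_len, dtype=torch.long)
answers = [["4"] * grpo_size, ["9"] * grpo_size]  # one gold answer per completion

rewards, successes, meta = batched_rewards(
    StubTokenizer(), completions, answers, device=torch.device("cpu")
)
# rewards and successes have shape [2, 4, 2]: batch x group x reward function,
# and meta["func_names"] records which column belongs to which reward.
print(rewards.shape, meta["func_names"])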