From cd2ff615d9511afb501668f811a170e9b4e226e7 Mon Sep 17 00:00:00 2001
From: Egor
Date: Mon, 28 Jul 2025 21:35:39 +0300
Subject: [PATCH 1/2] fix

---
 torchtune/dev/rl/rewards.py | 100 +++++++++++++++++++++++++++++++++++-
 1 file changed, 99 insertions(+), 1 deletion(-)

diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py
index 8d1ec1e79f..47404ba6c7 100644
--- a/torchtune/dev/rl/rewards.py
+++ b/torchtune/dev/rl/rewards.py
@@ -7,10 +7,17 @@
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
+
+import math_verify
 
 import torch
 
+from torchtune.modules.transforms.tokenizers import (
+    HuggingFaceModelTokenizer,
+    ModelTokenizer,
+)
+
 
 @dataclass
 class RewardOutput:
@@ -216,3 +223,94 @@ def __call__(
             },
             successes=successes,
         )
+
+
+def at_least_one_space_between_think_tags(
+    cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+    """Did the model at least try to think?"""
+    if len(cot) > 0:
+        return 1.0, 1.0  # (reward, success)
+    else:
+        return 0.0, 0.0
+
+
+def math_response_correct(
+    cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+    """Did it get the right answer?"""
+    if potential_answer is None:
+        return 0.0, 0.0  # (reward, success)
+    gold = math_verify.parse(answer)
+    attempt = math_verify.parse(potential_answer)
+
+    if math_verify.verify(gold, attempt):
+        return 100.0, 1.0
+    if answer in potential_answer:
+        return 50.0, 0.0
+    if len(potential_answer) > 0:
+        return 1.0, 0.0
+    return 0.0, 0.0
+
+
+def extract_tags(text: str) -> tuple[str, str]:
+    """
+    Parse XML-like <think> and <answer> tags from text.
+    Returns (cot, potential_answer), stripped, with "" for any missing tag.
+    """
+    think_pattern = r"<think>(.*?)</think>"
+    answer_pattern = r"<answer>(.*?)</answer>"
+    think_match = re.search(think_pattern, text, re.DOTALL)
+    answer_match = re.search(answer_pattern, text, re.DOTALL)
+    cot = think_match.group(1).strip() if think_match else ""
+    potential_answer = answer_match.group(1).strip() if answer_match else ""
+    return cot, potential_answer
+
+
+def batched_rewards(
+    tokenizer: Union[ModelTokenizer, HuggingFaceModelTokenizer],
+    completions: torch.Tensor,
+    answers: list[list[str]],
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor, dict]:
+
+    reward_funcs = [
+        at_least_one_space_between_think_tags,
+        math_response_correct,
+    ]
+
+    num_reward_funcs = len(reward_funcs)
+
+    batch_size, grpo_size, _ = completions.shape
+
+    # TODO: should this be bfloat16?
+
+    rewards_tensor = torch.zeros(
+        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+    )
+
+    successes_tensor = torch.zeros(
+        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+    )
+
+    metadata = {"func_names": [f.__name__ for f in reward_funcs]}
+
+    for b in range(batch_size):
+
+        for g in range(grpo_size):
+
+            answer = answers[b][g]
+
+            text_completion = tokenizer.decode(completions[b, g].tolist())
+
+            cot, potential_answer = extract_tags(f"{text_completion}")
+
+            for rw_idx, reward_func in enumerate(reward_funcs):
+
+                reward, success = reward_func(cot, answer, potential_answer)
+
+                rewards_tensor[b, g, rw_idx] += reward
+
+                successes_tensor[b, g, rw_idx] += success
+
+    return rewards_tensor, successes_tensor, metadata

From ffd36500730abfa42423c758197ecf3e93f1a875 Mon Sep 17 00:00:00 2001
From: Egor
Date: Mon, 28 Jul 2025 21:46:31 +0300
Subject: [PATCH 2/2] fix import

---
 torchtune/dev/rl/rewards.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py
index 47404ba6c7..95c45ee9b0 100644
--- a/torchtune/dev/rl/rewards.py
+++ b/torchtune/dev/rl/rewards.py
@@ -9,8 +9,6 @@
 from dataclasses import dataclass, field
 from typing import Optional, Union
 
-import math_verify
-
 import torch
 
 from torchtune.modules.transforms.tokenizers import (
@@ -239,6 +237,8 @@ def math_response_correct(
     cot: str, answer: str, potential_answer: str
 ) -> tuple[float, float]:
     """Did it get the right answer?"""
+    import math_verify
+
     if potential_answer is None:
         return 0.0, 0.0  # (reward, success)
     gold = math_verify.parse(answer)
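
For reference, a minimal sketch of how the new batched_rewards helper could be exercised. It is illustrative only and not part of either patch; it assumes math-verify is installed and uses a hypothetical StubTokenizer that implements just the decode call the function needs.

import torch

from torchtune.dev.rl.rewards import batched_rewards


class StubTokenizer:
    """Hypothetical stand-in for a tokenizer; batched_rewards only calls decode() on it."""

    def decode(self, token_ids: list[int]) -> str:
        # Pretend every completion decodes to the same tagged string.
        return "<think>21 * 2 = 42</think><answer>42</answer>"


batch_size, grpo_size, seq_len = 2, 4, 16
completions = torch.zeros(batch_size, grpo_size, seq_len, dtype=torch.long)
answers = [["42"] * grpo_size for _ in range(batch_size)]

rewards, successes, metadata = batched_rewards(
    tokenizer=StubTokenizer(),
    completions=completions,
    answers=answers,
    device=torch.device("cpu"),
)

print(rewards.shape)           # torch.Size([2, 4, 2]); one slot per reward function
print(metadata["func_names"])  # ['at_least_one_space_between_think_tags', 'math_response_correct']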