20 commits
5278458 make it compatible with chunked hidden states selective log softmax (pluesclues, Nov 7, 2025)
65d6d9f Merge branch 'unslothai:main' into alternative_compute_chunked_loss (pluesclues, Nov 11, 2025)
1e49528 Merge branch 'unslothai:main' into alternative_compute_chunked_loss (pluesclues, Nov 12, 2025)
494f611 Refactor grpo_trainer for logps and entropies handling (pluesclues, Nov 12, 2025)
387939f Update fmt.Println message from 'Hello World' (pluesclues, Nov 12, 2025)
eccd41d Merge branch 'unslothai:main' into alternative_compute_chunked_loss (pluesclues, Nov 17, 2025)
95abf46 Refactor chunking logic for pixel values and image grid (pluesclues, Nov 17, 2025)
52b23ff Merge branch 'unslothai:main' into alternative_compute_chunked_loss (pluesclues, Nov 18, 2025)
f2102c8 Refactor padding logic with max_left_pad handling (pluesclues, Nov 18, 2025)
16f6be6 Merge branch 'unslothai:main' into alternative_compute_chunked_loss (pluesclues, Nov 19, 2025)
ac15b81 Clean up padding logic and remove unused comments (pluesclues, Nov 19, 2025)
f49bf4f Merge branch 'unslothai:main' into alternative_compute_chunked_loss (pluesclues, Nov 20, 2025)
ea6964a Update vllm usage conditions with importance sampling check (pluesclues, Nov 20, 2025)
ca6b826 Disable TRL importance sampling logic (pluesclues, Nov 21, 2025)
f2b29ee Refactor error handling in rl_replacements.py (pluesclues, Nov 21, 2025)
8d263b1 Refactor vllm_importance_sampling_correction checks (pluesclues, Nov 21, 2025)
d332e93 Add grpo_selective_log_softmax to RL replacements (pluesclues, Nov 21, 2025)
4be35d8 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 21, 2025)
84c56aa Refactor code for readability and consistency (pluesclues, Nov 21, 2025)
9b2539c [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 21, 2025)
4 changes: 4 additions & 0 deletions unsloth/models/rl.py
@@ -209,6 +209,7 @@ def unsloth_prediction_step(
exec(f"Trainer.prediction_step=unsloth_prediction_step")


grpo_selective_log_softmax = RL_REPLACEMENTS["grpo_selective_log_softmax"]
selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"]
calculate_pad_tokens_in_prompt = RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"]
create_completion_attention_mask = RL_REPLACEMENTS["create_completion_attention_mask"]
@@ -253,6 +254,7 @@ def wrapper(self, *args, **kwargs):
"triton.cudagraphs" : False,
}}

{grpo_selective_log_softmax_code}
{selective_log_softmax_code}
{calculate_pad_tokens_in_prompt_code}
{create_completion_attention_mask_code}
@@ -909,6 +911,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):

# Selective log softmax and other functions
selective_log_softmax_code = inspect.getsource(selective_log_softmax)
grpo_selective_log_softmax_code = inspect.getsource(grpo_selective_log_softmax)
calculate_pad_tokens_in_prompt_code = inspect.getsource(
calculate_pad_tokens_in_prompt
)
@@ -938,6 +941,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
max_seq_length_call = max_seq_length_call,
max_seq_length_post = max_seq_length_post,
selective_log_softmax_code = selective_log_softmax_code,
grpo_selective_log_softmax_code = grpo_selective_log_softmax_code,
calculate_pad_tokens_in_prompt_code = calculate_pad_tokens_in_prompt_code,
create_completion_attention_mask_code = create_completion_attention_mask_code,
left_pack_padding_code = left_pack_padding_code,
216 changes: 171 additions & 45 deletions unsloth/models/rl_replacements.py
@@ -231,18 +231,20 @@ def grpo_trainer__generate_and_score_completions(function_name, function):

# The new multi-line string that will replace the line above
replacement_lines = """
max_left_pad = None
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
try:
# TRL 0.23.1 and below path
if not has_images:
# Left pad prompt before calculating old and ref hidden states
prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
self.model.for_training()
left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
max_left_pad = max(left_pad_tokens_per_prompt).item()
except:
# TRL 0.24.0 and above path
if images is None:
# Left pad prompt before calculating old and ref hidden states
prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
max_left_pad = max(left_pad_tokens_per_prompt).item()
self.model.for_training()"""

function = function.replace(line_to_replace, replacement_lines)
@@ -319,18 +321,27 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
if self.use_vllm:"""
function = function.replace(replace_part, new_replacement)

# Important note: we disable TRL's importance sampling logic
string_to_find = "if self.use_vllm and self.vllm_importance_sampling_correction:"

replacement_string = (
"if False and self.use_vllm and self.vllm_importance_sampling_correction:"
)

function = function.replace(string_to_find, replacement_string)

string_to_find = """ if "image_sizes" in prompt_inputs:
output["image_sizes"] = prompt_inputs["image_sizes"]"""

replacement_string = """ if "image_sizes" in prompt_inputs:
output["image_sizes"] = prompt_inputs["image_sizes"]

if self.use_vllm:
try:
if max_left_pad is not None:
output["max_left_pad"] = torch.tensor(sampling_per_token_logps.shape[0] * [max_left_pad]).unsqueeze(-1)
try:
if self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False):
output["sampling_per_token_logps"] = sampling_per_token_logps
except NameError:
output["sampling_per_token_logps"] = None"""

except NameError:
output["sampling_per_token_logps"] = None"""
function = function.replace(string_to_find, replacement_string)

if "wake_up()" not in function:
@@ -510,7 +521,6 @@ def _get_per_token_logps_and_entropies(
if compute_efficient:
return None, None
else:
# Otherwise, calculate normally:
if not hasattr(self, "_autocast_dtype"):
self._autocast_dtype = (
torch.float16
@@ -529,47 +539,156 @@ def _get_per_token_logps_and_entropies(
kwargs.get("image_sizes", None),
)

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

unwrapped_model = self.accelerator.unwrap_model(
model, keep_fp32_wrapper = False
)

B = input_ids.shape[0]
all_logprobs_list = []
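# Per-token log-probs are accumulated here one sample at a time: the forward
# pass below returns hidden states (UNSLOTH_RETURN_HIDDEN_STATES = "1") and the
# selective log softmax is computed in chunks against lm_head, so full-vocabulary
# logits are never materialized for the whole batch at once.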

if pixel_values is None:
left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(
input_ids, logits_to_keep, self.processing_class.pad_token_id
)
max_left_pad = max(left_pad_tokens_per_prompt).item()
input_ids = left_pack_padding(
input_ids, self.processing_class.pad_token_id
)
attention_mask = input_ids != self.processing_class.pad_token_id
attention_mask = attention_mask.to(attention_mask.dtype)
else:
max_left_pad = 0

input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0)
attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0)

def chunk_optional(tensor, chunks):
if tensor is None:
return [None] * chunks
return torch.chunk(tensor, chunks = chunks, dim = 0)

pixel_values_chunks = [None] * B
image_grid_thw_chunks = [None] * B
pixel_attention_mask_chunks = [None] * B

# This is the chunking logic from TRL 0.23.0
if image_grid_thw is not None and pixel_values is not None:
if image_grid_thw.shape[0] != B:
raise ValueError(
f"This logic requires image_grid_thw.shape[0] ({image_grid_thw.shape[0]}) "
f"to be equal to batch size B ({B})."
)

rows_per_sample = image_grid_thw.prod(dim = -1)
rows_per_sample_list = rows_per_sample.cpu().tolist()
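# rows_per_sample[i] = t * h * w for sample i, i.e. how many rows of the
# flattened pixel_values tensor belong to that sample; the vision tensors are
# split by these counts rather than chunked evenly.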

pixel_values_chunks = list(
torch.split(pixel_values, rows_per_sample_list, dim = 0)
)
if pixel_attention_mask is not None:
pixel_attention_mask_chunks = list(
torch.split(pixel_attention_mask, rows_per_sample_list, dim = 0)
)

image_grid_thw_chunks = list(
torch.chunk(image_grid_thw, chunks = B, dim = 0)
)

elif pixel_values is not None:
pixel_values_chunks = list(torch.chunk(pixel_values, chunks = B, dim = 0))
if pixel_attention_mask is not None:
pixel_attention_mask_chunks = list(
torch.chunk(pixel_attention_mask, chunks = B, dim = 0)
)

if image_sizes is not None and not isinstance(image_sizes, torch.Tensor):
image_sizes_chunks = [[size] for size in image_sizes]
else:
image_sizes_chunks = chunk_optional(image_sizes, B)

lm_head = self.model.get_output_embeddings().weight
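# Capture the output projection and the model's logit transforms (softcapping,
# scale, temperature) so the chunked selective log softmax reproduces the same
# logits the full forward pass would have produced.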
temperature = self.temperature
logit_softcapping = getattr(model.config, "final_logit_softcapping", 0)
if logit_softcapping is None:
logit_softcapping = 0
logit_scale_multiply = getattr(model.config, "logit_scale", 0)
if logit_scale_multiply is None:
logit_scale_multiply = 0
logit_scale_divide = getattr(model.config, "logits_scaling", 0)
if logit_scale_divide is None:
logit_scale_divide = 0

zipped_inputs = zip(
input_ids_chunks,
attention_mask_chunks,
pixel_values_chunks,
image_grid_thw_chunks,
pixel_attention_mask_chunks,
image_sizes_chunks,
)
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

with torch.amp.autocast(device_type = "cuda", dtype = self._autocast_dtype):
with torch.inference_mode():
if pixel_values is None:
attention_mask = input_ids != self.processing_class.pad_token_id
attention_mask = attention_mask.to(attention_mask.dtype)
# We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = unwrapped_model(
input_ids = input_ids,
attention_mask = attention_mask,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
pixel_attention_mask = pixel_attention_mask,
image_sizes = image_sizes,
# logits_to_keep = logits_to_keep + 1,
).logits
else:
logits = unwrapped_model(
input_ids = input_ids,
attention_mask = attention_mask,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
pixel_attention_mask = pixel_attention_mask,
image_sizes = image_sizes,
logits_to_keep = logits_to_keep + 1,
).logits

entropies = None
if compute_entropy:
from trl.trainer.utils import entropy_from_logits

entropies = entropy_from_logits(logits)
with torch.no_grad():
for (
input_ids_chunk,
attention_mask_chunk,
pixel_values_chunk,
image_grid_thw_chunk,
pixel_attention_mask_chunk,
image_sizes_chunk,
) in zipped_inputs:
if pixel_values is None:
logits_chunk = unwrapped_model(
input_ids = input_ids_chunk,
attention_mask = attention_mask_chunk,
pixel_values = pixel_values_chunk,
image_grid_thw = image_grid_thw_chunk,
pixel_attention_mask = pixel_attention_mask_chunk,
image_sizes = image_sizes_chunk,
).logits

completion_input_ids_chunk = input_ids_chunk[
:, -(logits_to_keep + max_left_pad) :
]
logits_chunk = logits_chunk[
:, -(logits_to_keep + max_left_pad + 1) :, :
]
logits_chunk = logits_chunk[:, :-1, :]
else:
logits_chunk = unwrapped_model(
input_ids = input_ids_chunk,
attention_mask = attention_mask_chunk,
pixel_values = pixel_values_chunk,
image_grid_thw = image_grid_thw_chunk,
pixel_attention_mask = pixel_attention_mask_chunk,
image_sizes = image_sizes_chunk,
logits_to_keep = logits_to_keep + 1,
).logits

logits_chunk = logits_chunk[:, :-1, :]
completion_input_ids_chunk = input_ids_chunk[
:, -logits_to_keep:
]
logprobs_chunk = chunked_hidden_states_selective_log_softmax(
logits_chunk,
lm_head,
completion_input_ids_chunk,
chunks = 8,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
logit_softcapping = logit_softcapping,
temperature = temperature,
)

all_logprobs_list.append(logprobs_chunk)
logprobs = torch.cat(all_logprobs_list, dim = 0)
entropies = None

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
# logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
return logits.detach(), entropies # logps, entropies

return logprobs.detach(), entropies # logps, entropies
# input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
@@ -700,6 +819,7 @@ def compute_loss(
if logit_scale_divide is None:
logit_scale_divide = 0

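# Left-pad offset recorded in _generate_and_score_completions; passed through
# to the loss computation below.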
max_left_pad = inputs.get("max_left_pad", 0)
if per_token_logps is not None:
if ref_hidden_states is not None:
ref_hidden_states = ref_hidden_states[
@@ -731,6 +851,7 @@ def compute_loss(
max_completion_length = self.args.max_completion_length,
delta = self.args.delta,
temperature = self.args.temperature,
max_left_pad = max_left_pad,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
@@ -761,6 +882,7 @@ def compute_loss(
max_completion_length = self.args.max_completion_length,
delta = self.args.delta,
temperature = self.args.temperature,
max_left_pad = max_left_pad,
logit_softcapping = logit_softcapping,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
@@ -797,7 +919,11 @@ def compute_loss(
self._metrics["completion_length"].append(completion_length.item())
self._metrics["kl"].append(mean_kl.item())

if self.use_vllm and delta is not None:
if (
self.use_vllm
and delta is not None
and getattr(self, "vllm_importance_sampling_correction", False)
):
mean_delta = (
torch.mean(delta)
if delta.numel() > 0