Chunk Across Batch and Context length for logprob calculations for grpo #3628

base: nightly
The patch edits the generated trainer source in place via string replacement. The first hunk left-packs prompt padding and records the per-prompt pad counts before the old/ref logprob pass:

```diff
@@ -231,18 +231,20 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
     # The new multi-line string that will replace the line above
     replacement_lines = """
        max_left_pad = None
        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
        try:
            # TRL 0.23.1 and below path
            if not has_images:
                # Left pad prompts before calculating old and ref hidden states
                prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
            self.model.for_training()
            left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
            max_left_pad = max(left_pad_tokens_per_prompt).item()
        except:
            # TRL 0.24.0 and above path
            if images is None:
                # Left pad prompts before calculating old and ref hidden states
                prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)
                left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
                max_left_pad = max(left_pad_tokens_per_prompt).item()
            self.model.for_training()"""

     function = function.replace(line_to_replace, replacement_lines)
```
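For reference, `left_pack_padding` and `calculate_pad_tokens_in_prompt` are Unsloth-side helpers; here is a minimal sketch of what they plausibly compute (hypothetical re-implementations, not the library code):

```python
import torch

def left_pack_padding(ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    # Move every pad token to the left so real tokens are right-aligned;
    # a stable argsort keeps the relative order of the real tokens intact.
    order = torch.argsort((ids != pad_id).int(), dim=-1, stable=True)
    return torch.gather(ids, -1, order)

def calculate_pad_tokens_in_prompt(ids: torch.Tensor, logits_to_keep: int, pad_id: int) -> torch.Tensor:
    # Count pad tokens in the prompt region, i.e. everything to the left of
    # the `logits_to_keep` completion tokens, for each row of the batch.
    prompt = ids[:, : ids.shape[1] - logits_to_keep]
    return (prompt == pad_id).sum(dim=-1)

ids = torch.tensor([[7, 8, 0, 0, 5, 6],
                    [0, 9, 4, 0, 3, 2]])  # 0 = pad
print(left_pack_padding(ids, 0))
# tensor([[0, 0, 7, 8, 5, 6],
#         [0, 0, 9, 4, 3, 2]])
print(calculate_pad_tokens_in_prompt(ids, 2, 0))  # tensor([2, 2])
```

With this packing, each row's padding sits at the front, and `max_left_pad` bounds how many extra leading positions must be kept when slicing completion logits later.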
```diff
@@ -319,18 +321,27 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         if self.use_vllm:"""
     function = function.replace(replace_part, new_replacement)

     # Important note: we disable TRL's importance sampling logic
     string_to_find = "if self.use_vllm and self.vllm_importance_sampling_correction:"

     replacement_string = (
         "if False and self.use_vllm and self.vllm_importance_sampling_correction:"
     )

     function = function.replace(string_to_find, replacement_string)

     string_to_find = """        if "image_sizes" in prompt_inputs:
            output["image_sizes"] = prompt_inputs["image_sizes"]"""

     replacement_string = """        if "image_sizes" in prompt_inputs:
            output["image_sizes"] = prompt_inputs["image_sizes"]
            if self.use_vllm:
                try:
                    if max_left_pad is not None:
                        output["max_left_pad"] = torch.tensor(sampling_per_token_logps.shape[0] * [max_left_pad]).unsqueeze(-1)
                    try:
                        if self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False):
                            output["sampling_per_token_logps"] = sampling_per_token_logps
                    except NameError:
                        output["sampling_per_token_logps"] = None
                except NameError:
                    output["sampling_per_token_logps"] = None"""
     function = function.replace(string_to_find, replacement_string)

     if "wake_up()" not in function:
```
```diff
@@ -510,7 +521,6 @@ def _get_per_token_logps_and_entropies(
     if compute_efficient:
         return None, None
     else:
-        # Otherwise, calculate normally:
         if not hasattr(self, "_autocast_dtype"):
             self._autocast_dtype = (
                 torch.float16
```
The core change replaces the single full-batch forward pass with a per-sample loop:

```diff
@@ -529,47 +539,156 @@ def _get_per_token_logps_and_entropies(
             kwargs.get("image_sizes", None),
         )

-        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-
         unwrapped_model = self.accelerator.unwrap_model(
             model, keep_fp32_wrapper = False
         )

+        B = input_ids.shape[0]
+        all_logprobs_list = []
+
+        if pixel_values is None:
+            left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(
+                input_ids, logits_to_keep, self.processing_class.pad_token_id
+            )
+            max_left_pad = max(left_pad_tokens_per_prompt).item()
+            input_ids = left_pack_padding(
+                input_ids, self.processing_class.pad_token_id
+            )
+            attention_mask = input_ids != self.processing_class.pad_token_id
+            attention_mask = attention_mask.to(attention_mask.dtype)
+        else:
+            max_left_pad = 0
+
+        input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0)
+        attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0)
+
+        def chunk_optional(tensor, chunks):
+            if tensor is None:
+                return [None] * chunks
+            return torch.chunk(tensor, chunks = chunks, dim = 0)
+
+        pixel_values_chunks = [None] * B
+        image_grid_thw_chunks = [None] * B
+        pixel_attention_mask_chunks = [None] * B
+
+        # This is the chunking logic from TRL 0.23.0
+        if image_grid_thw is not None and pixel_values is not None:
+            if image_grid_thw.shape[0] != B:
+                raise ValueError(
+                    f"This logic requires image_grid_thw.shape[0] ({image_grid_thw.shape[0]}) "
+                    f"to be equal to batch size B ({B})."
+                )
+
+            rows_per_sample = image_grid_thw.prod(dim = -1)
+            rows_per_sample_list = rows_per_sample.cpu().tolist()
+
+            pixel_values_chunks = list(
+                torch.split(pixel_values, rows_per_sample_list, dim = 0)
+            )
+            if pixel_attention_mask is not None:
+                pixel_attention_mask_chunks = list(
+                    torch.split(pixel_attention_mask, rows_per_sample_list, dim = 0)
+                )
+
+            image_grid_thw_chunks = list(
+                torch.chunk(image_grid_thw, chunks = B, dim = 0)
+            )
+
+        elif pixel_values is not None:
+            pixel_values_chunks = list(torch.chunk(pixel_values, chunks = B, dim = 0))
+            if pixel_attention_mask is not None:
+                pixel_attention_mask_chunks = list(
+                    torch.chunk(pixel_attention_mask, chunks = B, dim = 0)
+                )
+
+        if image_sizes is not None and not isinstance(image_sizes, torch.Tensor):
+            image_sizes_chunks = [[size] for size in image_sizes]
+        else:
+            image_sizes_chunks = chunk_optional(image_sizes, B)
+
+        lm_head = self.model.get_output_embeddings().weight
+        temperature = self.temperature
+        logit_softcapping = getattr(model.config, "final_logit_softcapping", 0)
+        if logit_softcapping is None:
+            logit_softcapping = 0
+        logit_scale_multiply = getattr(model.config, "logit_scale", 0)
+        if logit_scale_multiply is None:
+            logit_scale_multiply = 0
+        logit_scale_divide = getattr(model.config, "logits_scaling", 0)
+        if logit_scale_divide is None:
+            logit_scale_divide = 0
+
+        zipped_inputs = zip(
+            input_ids_chunks,
+            attention_mask_chunks,
+            pixel_values_chunks,
+            image_grid_thw_chunks,
+            pixel_attention_mask_chunks,
+            image_sizes_chunks,
+        )
+        os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

         with torch.amp.autocast(device_type = "cuda", dtype = self._autocast_dtype):
-            with torch.inference_mode():
-                if pixel_values is None:
-                    attention_mask = input_ids != self.processing_class.pad_token_id
-                    attention_mask = attention_mask.to(attention_mask.dtype)
-                    # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-                    logits = unwrapped_model(
-                        input_ids = input_ids,
-                        attention_mask = attention_mask,
-                        pixel_values = pixel_values,
-                        image_grid_thw = image_grid_thw,
-                        pixel_attention_mask = pixel_attention_mask,
-                        image_sizes = image_sizes,
-                        # logits_to_keep = logits_to_keep + 1,
-                    ).logits
-                else:
-                    logits = unwrapped_model(
-                        input_ids = input_ids,
-                        attention_mask = attention_mask,
-                        pixel_values = pixel_values,
-                        image_grid_thw = image_grid_thw,
-                        pixel_attention_mask = pixel_attention_mask,
-                        image_sizes = image_sizes,
-                        logits_to_keep = logits_to_keep + 1,
-                    ).logits
-
-            entropies = None
-            if compute_entropy:
-                from trl.trainer.utils import entropy_from_logits
-
-                entropies = entropy_from_logits(logits)
+            with torch.no_grad():
+                for (
+                    input_ids_chunk,
+                    attention_mask_chunk,
+                    pixel_values_chunk,
+                    image_grid_thw_chunk,
+                    pixel_attention_mask_chunk,
+                    image_sizes_chunk,
+                ) in zipped_inputs:
+                    if pixel_values is None:
+                        logits_chunk = unwrapped_model(
+                            input_ids = input_ids_chunk,
+                            attention_mask = attention_mask_chunk,
+                            pixel_values = pixel_values_chunk,
+                            image_grid_thw = image_grid_thw_chunk,
+                            pixel_attention_mask = pixel_attention_mask_chunk,
+                            image_sizes = image_sizes_chunk,
+                        ).logits
+
+                        completion_input_ids_chunk = input_ids_chunk[
+                            :, -(logits_to_keep + max_left_pad) :
+                        ]
+                        logits_chunk = logits_chunk[
+                            :, -(logits_to_keep + max_left_pad + 1) :, :
+                        ]
+                        logits_chunk = logits_chunk[:, :-1, :]
+                    else:
+                        logits_chunk = unwrapped_model(
+                            input_ids = input_ids_chunk,
+                            attention_mask = attention_mask_chunk,
+                            pixel_values = pixel_values_chunk,
+                            image_grid_thw = image_grid_thw_chunk,
+                            pixel_attention_mask = pixel_attention_mask_chunk,
+                            image_sizes = image_sizes_chunk,
+                            logits_to_keep = logits_to_keep + 1,
+                        ).logits
+
+                        logits_chunk = logits_chunk[:, :-1, :]
+                        completion_input_ids_chunk = input_ids_chunk[
+                            :, -logits_to_keep:
+                        ]
+                    logprobs_chunk = chunked_hidden_states_selective_log_softmax(
+                        logits_chunk,
+                        lm_head,
+                        completion_input_ids_chunk,
+                        chunks = 8,
+                        logit_scale_multiply = logit_scale_multiply,
+                        logit_scale_divide = logit_scale_divide,
+                        logit_softcapping = logit_softcapping,
+                        temperature = temperature,
+                    )
+
+                    all_logprobs_list.append(logprobs_chunk)
+        logprobs = torch.cat(all_logprobs_list, dim = 0)
+        entropies = None

         os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
-        # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-        return logits.detach(), entropies  # logps, entropies
+        return logprobs.detach(), entropies  # logps, entropies
         # input_ids = input_ids[:, -logits_to_keep:]
         # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
         # See https://github.com/huggingface/trl/issues/2770
```
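Two pieces of the new chunked path are worth unpacking. First, the vision branch splits the ragged `pixel_values` stack back into per-sample chunks: each row of `image_grid_thw` is a (t, h, w) grid, so its product is the number of patch rows that sample contributed. A toy demonstration of the split:

```python
import torch

image_grid_thw = torch.tensor([[1, 2, 2],   # sample 0 contributes 4 patch rows
                               [1, 3, 3]])  # sample 1 contributes 9 patch rows
pixel_values = torch.randn(13, 1176)        # all samples stacked along dim 0
rows_per_sample = image_grid_thw.prod(dim=-1).tolist()  # [4, 9]
chunks = torch.split(pixel_values, rows_per_sample, dim=0)
print([tuple(c.shape) for c in chunks])  # [(4, 1176), (9, 1176)]
```

Second, because `UNSLOTH_RETURN_HIDDEN_STATES=1` makes the model return hidden states where logits would normally be, the per-token logprobs are recovered by projecting through the LM head slice by slice. A rough sketch of the idea behind `chunked_hidden_states_selective_log_softmax`, under a simplified hypothetical signature (the real helper also handles logit softcapping and scaling):

```python
def selective_log_softmax_sketch(hidden, lm_head, target_ids,
                                 chunks=8, temperature=1.0):
    # Project hidden states (B, L, H) to logits one sequence slice at a time
    # so the full (B, L, vocab) tensor is never materialized, then gather
    # only the log-probability of each target token.
    out = []
    for h, t in zip(torch.chunk(hidden, chunks, dim=1),
                    torch.chunk(target_ids, chunks, dim=1)):
        logits = (h @ lm_head.T).float() / temperature  # (B, l, vocab)
        logps = torch.log_softmax(logits, dim=-1)
        out.append(torch.gather(logps, -1, t.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out, dim=1)
```

Combined with the batch-dimension chunking above (`torch.chunk(input_ids, chunks = B, dim = 0)` runs one sample per forward), peak memory for the old/ref logprob pass scales with a single sample and a single vocab-projection slice rather than the whole batch, which is the point of the PR.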
```diff
@@ -678,14 +797,14 @@ def compute_loss(
     # ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
     # else:
     # ref_per_token_logps = None
-    ref_hidden_states = inputs.get("ref_per_token_logps", None)
+    ref_logps = inputs.get("ref_per_token_logps", None)
     # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
     # x - x.detach() allows for preserving gradients from x
     advantages = inputs["advantages"]
     # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
     # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
     # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
-    old_hidden_states = inputs.get("old_per_token_logps", None)
+    old_logps = inputs.get("old_per_token_logps", None)

     input_ids = input_ids[:, -logits_to_keep:]
```
```diff
@@ -700,6 +819,7 @@ def compute_loss(
     if logit_scale_divide is None:
         logit_scale_divide = 0

+    max_left_pad = inputs.get("max_left_pad", 0)
     if per_token_logps is not None:
         if ref_hidden_states is not None:
             ref_hidden_states = ref_hidden_states[
```
```diff
@@ -731,6 +851,7 @@ def compute_loss(
         max_completion_length = self.args.max_completion_length,
         delta = self.args.delta,
         temperature = self.args.temperature,
+        max_left_pad = max_left_pad,
         logit_softcapping = logit_softcapping,
         logit_scale_multiply = logit_scale_multiply,
         logit_scale_divide = logit_scale_divide,
```
```diff
@@ -751,8 +872,8 @@ def compute_loss(
         logits_to_keep = logits_to_keep,
         completion_mask = completion_mask,
         advantages = advantages,
-        old_hidden_states = old_hidden_states,
-        ref_hidden_states = ref_hidden_states,
+        old_logps = old_logps,
+        ref_logps = ref_logps,
         n_chunks = self.args.unsloth_num_chunks,
         loss_type = self.args.loss_type,
         importance_sampling_level = self.importance_sampling_level,
```
```diff
@@ -761,6 +882,7 @@ def compute_loss(
         max_completion_length = self.args.max_completion_length,
         delta = self.args.delta,
         temperature = self.args.temperature,
+        max_left_pad = max_left_pad,
         logit_softcapping = logit_softcapping,
         logit_scale_multiply = logit_scale_multiply,
         logit_scale_divide = logit_scale_divide,
```
```diff
@@ -797,7 +919,11 @@ def compute_loss(
     self._metrics["completion_length"].append(completion_length.item())
    self._metrics["kl"].append(mean_kl.item())

-    if self.use_vllm and delta is not None:
+    if (
+        self.use_vllm
+        and delta is not None
+        and getattr(self, "vllm_importance_sampling_correction", False)
+    ):
         mean_delta = (
             torch.mean(delta)
             if delta.numel() > 0
```
Review comment:

The new block in `_generate_and_score_completions` builds `output["max_left_pad"]` using `sampling_per_token_logps.shape` before that variable is guaranteed to exist. In the common non-vLLM path (or when importance sampling correction is disabled), `sampling_per_token_logps` is never defined, so hitting this code will raise a `NameError` before any completions are returned, breaking GRPO training without vLLM. The guard below only protects the later assignment, so `max_left_pad` needs to be gated on the presence of sampling logprobs or sized from another tensor.
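A minimal sketch of one way to apply that suggestion, sizing the tensor from an input that always exists on this path and attaching sampling logprobs only when they were produced (illustrative, not part of the PR; variables come from the surrounding trainer body):

```python
if max_left_pad is not None:
    # prompt_completion_ids is always defined here, unlike
    # sampling_per_token_logps, so use it to size the per-row tensor.
    num_rows = prompt_completion_ids.shape[0]
    output["max_left_pad"] = torch.full((num_rows, 1), max_left_pad)
try:
    output["sampling_per_token_logps"] = sampling_per_token_logps
except NameError:
    output["sampling_per_token_logps"] = None
```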