Draft [models] feat: Add a modeling patch gen sample for qwen3#424
Draft [models] feat: Add a modeling patch gen sample for qwen3#424piyifan123 wants to merge 5 commits intotransformers50from
Conversation
There was a problem hiding this comment.
Code Review
The pull request introduces a robust code generation framework for patching HuggingFace models, moving away from runtime monkey patching towards a more maintainable AST-based approach. This is a significant improvement for debugging and understanding model modifications. The changes include a patch specification DSL, the core code generator, and a CLI runner. While the overall design is sound, several high-severity issues related to the correctness and reliability of the code generation process have been identified, particularly concerning AST node end line estimation, decorator removal, indentation handling, and the diff utility's external command usage. Addressing these will ensure the generated code is consistently correct and well-formatted.
| unparsed = ast.unparse(node) | ||
| return node.lineno + unparsed.count("\n") | ||
| except Exception: | ||
| return node.lineno + 10 # Rough estimate |
There was a problem hiding this comment.
The fallback return node.lineno + 10 for get_node_end_line is a rough estimate and highly unreliable. This can lead to incorrect code segments being extracted or replaced, potentially corrupting the generated output. A more robust method is needed to accurately determine the end line when end_lineno is not available, perhaps by parsing the unparsed source to find the actual end of the statement or definition.
| # Remove the patch decorator lines if present (handles multi-line decorators) | ||
| source_lines = replacement_source.splitlines() | ||
| filtered_lines = [] | ||
| in_patch_decorator = False | ||
| paren_depth = 0 | ||
| for line in source_lines: | ||
| stripped = line.strip() | ||
| # Start of a patch decorator | ||
| if stripped.startswith("@") and ("override_method" in stripped or "config." in stripped): | ||
| in_patch_decorator = True | ||
| paren_depth = stripped.count("(") - stripped.count(")") | ||
| # If decorator is closed on same line, we're done skipping | ||
| if paren_depth <= 0: | ||
| in_patch_decorator = False | ||
| continue | ||
| # Continuation of multi-line decorator | ||
| if in_patch_decorator: | ||
| paren_depth += stripped.count("(") - stripped.count(")") | ||
| if paren_depth <= 0: | ||
| in_patch_decorator = False | ||
| continue | ||
| filtered_lines.append(line) | ||
| cleaned_replacement_source = "\n".join(filtered_lines) |
There was a problem hiding this comment.
The logic for removing patch decorator lines relies on string manipulation and parenthesis counting, which is fragile. This approach is highly susceptible to breaking with different decorator formats, multi-line decorators, or future Python versions, leading to incorrect generated code. Consider using AST manipulation to remove decorators more robustly, or ensure the decorator parsing is more resilient.
| original_line = source_lines[method_start] | ||
| indent = len(original_line) - len(original_line.lstrip()) | ||
| else: | ||
| indent = 4 # Default class method indent |
There was a problem hiding this comment.
Hardcoding the default class method indentation to 4 can lead to inconsistent formatting in the generated code if the original source or the replacement code uses a different indentation style (e.g., 2 spaces). This impacts the maintainability and readability of the generated output. It would be better to infer the indentation from the original class or method being replaced.
| stripped = line[preserved_indent:] if len(line) >= preserved_indent else line.lstrip() | ||
| indented_preserved_lines.append(" " * indent + stripped) |
There was a problem hiding this comment.
The re-indentation logic stripped = line[preserved_indent:] if len(line) >= preserved_indent else line.lstrip() might not handle all edge cases correctly. Specifically, if preserved_indent is greater than the line's actual indentation but less than its length, it could strip too much. If preserved_indent is greater than the line length, line[preserved_indent:] would result in an empty string, which might not be the desired behavior for lines that should retain some content. This could lead to malformed code in the output.
|
|
||
| # 3. Process ALL module-level nodes in their original order | ||
| # This preserves the exact structure of the original file | ||
| for node in self.source_ast.body: |
There was a problem hiding this comment.
The patches argument to self._generate_class_source is passed as an empty dictionary {}. This is a logical error, as _generate_class_source expects relevant patches to be applied. This means that method overrides or other class-level patches might not be correctly applied during the generation process. It should likely be self.config.get_method_overrides() or a filtered set of patches relevant to the current class.
| str(generated_path), | ||
| ] |
There was a problem hiding this comment.
The arguments for the delta command appear to be incorrect. --file-modified-label typically expects a label string, not a file path, and delta usually takes two file paths for comparison. This will likely cause the delta command to fail or produce unexpected output, breaking the diff functionality when delta is installed.
| original_tmp_path.unlink(missing_ok=True) | ||
|
|
There was a problem hiding this comment.
The cleanup of original_tmp_path using unlink(missing_ok=True) is not guaranteed to execute if subprocess.run(cmd) raises an exception. This can leave temporary files behind, leading to unnecessary disk usage. It's best practice to wrap such cleanup operations in a finally block to ensure they always run.
|
|
||
|
|
||
| @config.replace_function("apply_rotary_pos_emb", description="Use LigerKernel rotary embedding") | ||
| def apply_rotary_pos_emb_liger( |
There was a problem hiding this comment.
Do you want to manage these in a common plance?
|
|
||
| config.patches.append( | ||
| create_patch_from_external( | ||
| target="Qwen3MLP", |
There was a problem hiding this comment.
Nit: The param name is a little bit confusing. source & target feels like changing from source to target.
| config.patches.append( | ||
| create_patch_from_external( | ||
| target="Qwen3MLP", | ||
| source_module="liger_kernel.transformers.swiglu", |
There was a problem hiding this comment.
Is it possible to actually import it instead of using a string?
| hidden_states = outputs.last_hidden_state | ||
| # Only compute necessary logits, and do not upcast them to float if we are not computing the loss | ||
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep | ||
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
There was a problem hiding this comment.
The logic to calculate loss here may changed if we use liger-kernel
|
|
||
| hidden_states = inputs_embeds | ||
| position_embeddings = self.rotary_emb(hidden_states, position_ids) | ||
|
|
There was a problem hiding this comment.
How to patch Qwen3Model.forward for SP case?
| ├── patch_spec.py # Patch specification DSL | ||
| ├── codegen.py # AST-based code generator | ||
| ├── run_codegen.py # CLI runner script | ||
| ├── patches/ |
There was a problem hiding this comment.
Can we still keep the position of the patch code and the generated code in the model dir
What does this PR do?
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includemisc,ci,config,docs,data,dist,omni,logging,model,optim,ckpt,release,task,perf,ops,parallel,like[ci, data, model]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][parallel, model] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always