[Draft] QAD algorithms - SmoothLAQ, AdaRound #1150
Conversation
Signed-off-by: realAsma <akuriparambi@nvidia.com>

minor
Signed-off-by: realAsma <akuriparambi@nvidia.com>

Respect narrow_range in IntCastSTEFunction for scale learning
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor

Fix scale_after_dequant for non-NVFP4 quantizers
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor
…quant

- Broaden local_hessian_calibrate to handle INT block quant (not just FP4)
- Support mse, local_hessian, and max methods in scale_after_dequant
- Add _convert_to_static_block_quantizers helper for max_calibrate path

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor
- Refactor utils.py to support local JSONL datasets (files/dirs) via --dataset and tokenized dataset caching via --dataset_cache_path
- Normalize Daring-Anteater conversations to standard messages format
- Add distributed-aware tokenization with per-rank sharding and merging
- Wire dataset_cache_path through launch.sh and main.py DataArguments
- Update README with local JSONL dataset and caching examples
- Remove unused NVFP4StaticQuantizer import in model_calib.py
- Fix import ordering in vllm plugin and test_quantize_api

Made-with: Cursor
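The local JSONL handling described above can be sketched as follows. This is a minimal, hypothetical `load_jsonl` helper, not the actual `utils.py` implementation, which additionally normalizes conversations and shards tokenization across ranks:

```python
import json
import os


def load_jsonl(path: str) -> list[dict]:
    """Load records from a local JSONL file or a directory of *.jsonl files.

    Minimal sketch of the --dataset handling described above; the real helper
    also normalizes conversations to the messages format and caches tokens.
    """
    if os.path.isdir(path):
        files = [
            os.path.join(path, name)
            for name in sorted(os.listdir(path))
            if name.endswith(".jsonl")
        ]
    else:
        files = [path]
    records: list[dict] = []
    for fname in files:
        with open(fname, encoding="utf-8") as f:
            # one JSON object per non-empty line
            records.extend(json.loads(line) for line in f if line.strip())
    return records
```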
Cast _per_block_scale and _per_tensor_scale to float32 before scale computation, then cast the final scale back to the input dtype. This prevents mixed-precision issues during fake quantization with learned scales. Made-with: Cursor
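The dtype discipline described here can be illustrated with NumPy standing in for torch. The multiplicative combination of block and tensor scales is an assumption of this sketch; the actual code operates on the `_per_block_scale` and `_per_tensor_scale` tensors inside the quantizer:

```python
import numpy as np


def compute_scale(per_block_scale: np.ndarray, per_tensor_scale: np.ndarray) -> np.ndarray:
    """Do the scale arithmetic in float32, then cast back to the input dtype.

    Sketch of the described fix: multiplying directly in float16/bfloat16 can
    lose precision, so upcast first and downcast only the final result.
    """
    out_dtype = per_block_scale.dtype
    scale = per_block_scale.astype(np.float32) * per_tensor_scale.astype(np.float32)
    return scale.astype(out_dtype)
```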
Signed-off-by: realAsma <akuriparambi@nvidia.com> Made-with: Cursor
- Simplify KDTrainer by assigning compute_loss_func in __init__ instead of overriding train() and compute_loss(); remove stale methods.
- Fix QADTrainer MRO to (KDTrainer, QATTrainer) so KD loss takes precedence during training.
- Guard modelopt state restore with `not is_quantized()` to avoid double-restoring an already-quantized model.
- Use tq.to() instead of tq.to_empty() in FSDP2 prepare patch.
- Remove unnecessary float32 dtype casts in StaticBlockScaleQuantizer.
- Add attn_implementation arg, fineweb_edu pretraining dataset, fp32 model loading, and pretrain tokenizer to LLM QAT example.

Made-with: Cursor
Signed-off-by: realAsma <akuriparambi@nvidia.com>
…r rename

- Rename _fp4_cast to _cast_ste, supporting both FP4 and INT cast
- Fix NVFP4StaticQuantizer -> StaticBlockScaleQuantizer references in adaround
- Add NVFP4StaticAdaRoundQuantizer import/restore in conversion.py
- Fix test config to use fp8_scale_sweep for matching calibration

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…eight support

- Add dist_loss_weight, beta_start, beta_end, freeze_weight to AdaRoundConfig
- Store hyperparams on NVFP4StaticAdaRoundQuantizer via from_nvfp4_quantizer
- Auto-detect adaround quantizers in QATTrainer, add annealed dist_loss to compute_loss
- Detach floor cast in _cast_ste when freeze_weight=True (only round_logits get grads)
- Add trainer tests: QATTrainer (with/without adaround), QADTrainer with adaround+KD

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
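The annealed dist_loss follows the rounding regularizer from the AdaRound paper, sum(1 - |2h - 1|^beta), where h is the soft rounding value derived from the learnable logits. A plain-Python sketch follows; the sigmoid-with-temperature parameterization is an assumption of this sketch, and the quantizer may parameterize round_logits differently:

```python
import math


def soft_round_probs(round_logits: list[float], temperature: float = 1.0) -> list[float]:
    """Map learnable logits to soft rounding values h in (0, 1)."""
    return [1.0 / (1.0 + math.exp(-x / temperature)) for x in round_logits]


def dist_loss(round_logits: list[float], beta: float, temperature: float = 1.0) -> float:
    """AdaRound rounding regularizer: sum(1 - |2h - 1|**beta).

    As beta anneals from beta_start toward beta_end, the penalty increasingly
    forces each h to a hard 0/1 rounding decision.
    """
    return sum(
        1.0 - abs(2.0 * h - 1.0) ** beta
        for h in soft_round_probs(round_logits, temperature)
    )
```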
… kernel Convert learnable DTensor parameters (_per_block_scale, round_logits) to local tensors inside the quantizer computation path, matching the existing to_local() conversion for inputs. Also fix fp4_step_size triton kernel to preserve input dtype instead of hardcoding float32. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…"loss" Remove compute_loss override entirely. Instead, compute dist_loss after super().training_step() and do a separate accelerator.backward() so gradients accumulate naturally. The Trainer's auto-logged "loss" now reflects base_loss (comparable with non-adaround jobs), while adaround/dist_loss and adaround/beta appear as separate TensorBoard scalars. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…izer Signed-off-by: realAsma <akuriparambi@nvidia.com> Made-with: Cursor
…en param support Move training-time knobs (beta_start, beta_end, dist_loss_weight, temperature) out of AdaRoundConfig and NVFP4StaticAdaRoundQuantizer into a new AdaRoundTrainingArguments dataclass on QATTrainer. Add trainable_params and frozen_params (fnmatch patterns) to QuantizationArguments so QATTrainer can configure requires_grad before optimizer creation. Remove redundant detach logic from _cast_ste. Made-with: Cursor
…assthrough Adaround-specific args (beta_start, beta_end, dist_loss_weight, temperature, freeze_weights) are passed through to main.py unchanged, so they don't need explicit parsing. Unknown args now collect into EXTRA_ARGS instead of erroring, making the script extensible without modification. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com>
- Fix adaround metrics not appearing in logs: override QATTrainer.log() to merge metrics before the callback pipeline (ProgressCallback was printing before _AdaRoundAuxCallback could inject them via on_log)
- Fix _liger_loss_func passing unexpected arg to zero-arg _compute()
- Update test_adaround_trainer to verify adaround quantizers and logged metrics directly instead of referencing removed _adaround_aux_callback

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Introduce a post-optimizer-step callback that computes per-weight MSE between quantized and original weights and applies independent manual SGD steps to push weights toward quantized grid points. The regularization coefficient is linearly annealed from qerr_coeff_start to qerr_coeff_stop over training. Qerr and AdaRound are mutually exclusive (ValueError if both set). Made-with: Cursor
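The linear annealing of the regularization coefficient can be written as a small helper (the name `qerr_coeff` is hypothetical; the argument names mirror qerr_coeff_start/qerr_coeff_stop from the description above):

```python
def qerr_coeff(step: int, total_steps: int, start: float, stop: float) -> float:
    """Linearly anneal the qerr regularization coefficient over training."""
    frac = min(step / max(total_steps, 1), 1.0)  # clamp to [0, 1]
    return start + (stop - start) * frac
```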
Signed-off-by: realAsma <akuriparambi@nvidia.com> Made-with: Cursor
PyYAML safe_load parses bare exponential notation (e.g. 1e-5) as strings instead of floats. This caused a TypeError in the LR scheduler when base_lr (a string) was multiplied by the lambda output (a float). Made-with: Cursor
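The root cause is PyYAML's YAML 1.1 float resolver, which requires a dot before the exponent, so `1.0e-5` parses as a float while bare `1e-5` parses as a string. A defensive coercion, shown here as a hypothetical helper rather than the fix actually used in the PR, looks like:

```python
def coerce_float(value):
    """Coerce numeric-looking strings (e.g. "1e-5" from yaml.safe_load)
    back to float; leave non-numeric values untouched."""
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            pass
    return value
```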
…s_coeff

- Qerr MSE is now always computed and reported (qerr/mse, qerr/coeff) every step. Default coefficients are 0 (monitor-only, no gradient applied). Set non-zero qerr_coeff_start/stop to enable active regularization.
- Rename dist_loss_weight -> dist_loss_coeff for consistency.
- Parse QuantErrorTrainingArguments in main.py and pass qerr_args to trainer.
- Extract _compute_mse helper in _QuantErrorAuxCallback.
- Replace mutual exclusion error with if/elif (adaround takes priority).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…t dims Reshape weight to [num_blocks, block_size] before calling _cast_ste in _compute_mse, matching what the quantizer forward pass does. Without this, MLP layers with intermediate_size=6144 trigger Triton's "arange's range must be a power of 2" error. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
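The reshape fix can be sketched with NumPy (block_size 16 is the NVFP4 block size assumed here):

```python
import numpy as np


def to_blocks(weight: np.ndarray, block_size: int = 16) -> np.ndarray:
    """Reshape a weight to [num_blocks, block_size] before block-wise casting,
    mirroring what the quantizer forward pass does. Each row then has a
    power-of-2 length, which satisfies Triton's arange constraint even when
    the original last dimension (e.g. 6144) is not a power of 2."""
    assert weight.size % block_size == 0, "weight must divide evenly into blocks"
    return weight.reshape(-1, block_size)
```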
…heck Add --qerr_reduction flag to QuantErrorTrainingArguments with 'sum' as default. The previous .mean() reduction produced per-element gradients too small to move quantization error. The .sum() reduction aggregates across all weight elements for stronger gradients. Logged metric key now reflects the reduction used (qerr/sum or qerr/mean). Also restores the adaround_args/qerr_args mutual exclusivity ValueError. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
HfArgumentParser creates default instances of both dataclasses, so both are always non-None. The real exclusivity is handled by the if/elif in _setup_training where adaround takes priority. Remove the now-incorrect ValueError and its corresponding test. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…izer The squared error `(q_weight - weight) ** 2` was allowing gradients to flow back through the quantizer's forward pass, corrupting weight updates. Adding `.detach()` ensures qerr only produces gradients w.r.t. the original weight, which is the correct STE-like behavior. Also skip the adaround aux step when round_logits are frozen (e.g. in the freeze-round paradigm) to avoid unnecessary computation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…tion LSQ computes s*Q(x/s) without pre-dividing weights, allowing the quantization grid to adapt as learned scales change during training. Refactors shared helpers out of scale_after_dequant for reuse. Includes lint/format fixes from pre-commit. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
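The forward math of LSQ-style fake quantization, s*Q(x/s), can be sketched numerically. Rounding would be handled with a straight-through estimator during training; `q_max` stands for the clip bound of the target format:

```python
import numpy as np


def lsq_fake_quant(x: np.ndarray, s: float, q_max: float) -> np.ndarray:
    """LSQ forward: s * Q(x / s) with Q = round + clip.

    The weight is not pre-divided by the scale, so the quantization grid
    adapts as s is learned.
    """
    return s * np.clip(np.round(x / s), -q_max, q_max)
```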
…ithms Rename scale_after_dequant -> smooth_lsq to match paper terminology. Add two new FP4 weight quantization algorithms: LAQ (learns a_max, derives s = a_max/Q_max, no weight pre-division) and SmoothLAQ (learns a_max, weights pre-divided by a_max, forward: Q_STE(w_a * Q_max) * s). Also cast LSQ/LAQ division to fp32 for precision. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- New _compute_laq_params() returns raw amaxes instead of roundtripping through per_block_scale (bad values default to Q_max so scale=1)
- enable_laq/enable_smooth_laq accept amax directly, bypass _enable_learnable_scales; rename _amax_param -> _amax_learnt
- SmoothLAQ no longer pre-divides weights; stores _amax_frozen buffer and divides by frozen scale in forward (optimizer updates original w)
- New _quantize_scale() helper shared across all 4 learnable algorithms for FP8 scale quantization (reused for both dequant and frozen scales)
- Unified _fake_quantize: compute scale -> quant input -> cast -> dequant
- amax property raises RuntimeError for SmoothLAQ (ambiguous)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace the deprecated _scale_after_dequant flag with _smooth_lsq and generalize AdaRound to work with any learnable-scale init algorithm (smooth_lsq, lsq, laq, smooth_laq) or plain calibration (max, mse, local_hessian), which auto-converts via smooth_lsq.

- Rename _scale_after_dequant -> _smooth_lsq across quantizer, calib, conversion, config, and tests
- Replace smooth_lsq_args with init_algorithm in AdaRoundConfig
- Add _compute_weight_scaled() for per-mode weight scaling
- Add parametrized test_adaround_with_init_algorithms covering all 7 algos

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extract _amax_to_scale(amax, max_bound) that computes scale = amax/max_bound with zero-guard, replacing duplicated scale computation across LAQ, SmoothLAQ, LSQ, SmoothLSQ in _fake_quantize, _compute_block_scales, _compute_laq_params, and _compute_weight_scaled. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
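A sketch of the shared helper (the zero-guard value is an assumption of this sketch; the real `_amax_to_scale` may clamp differently):

```python
import numpy as np


def amax_to_scale(amax: np.ndarray, max_bound: float, eps: float = 1e-12) -> np.ndarray:
    """Shared scale computation: scale = amax / max_bound with a zero-guard,
    so all-zero blocks cannot produce a zero scale (and a later divide-by-zero
    in dequantization)."""
    scale = np.asarray(amax, dtype=np.float32) / max_bound
    return np.where(scale > 0.0, scale, eps)
```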
📝 Walkthrough

This pull request introduces learnable-scale quantization algorithms (smooth_lsq, lsq, laq, smooth_laq, adaround), new training infrastructure for adaround and quantization-error regularization, a post-training quantization script, updated dataset handling for QAT examples, distillation loss refactoring, and corresponding test coverage.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant QAT as QAT Trainer
    participant Model as Model
    participant Quantizer as StaticBlockScale<br/>Quantizer
    participant AdaroundCB as AdaRound<br/>Callback
    User->>QAT: Initialize with adaround_args
    QAT->>Model: Check for NVFP4StaticAdaRoundQuantizer
    Model->>Quantizer: Enable AdaRound mode
    Quantizer->>Quantizer: Create trainable round_logits
    loop Training Step
        QAT->>Model: Forward pass
        Model->>Quantizer: Fake quantize with learned rounding
        Quantizer-->>Model: Rounded logits
        QAT->>AdaroundCB: Compute dist_loss from round_logits
        AdaroundCB->>Quantizer: Backprop to round_logits
        AdaroundCB->>AdaroundCB: Apply custom update with beta annealing
        AdaroundCB-->>QAT: Log adaround/dist_loss, beta
    end
    User->>QAT: Save model
    QAT->>Model: Export quantized weights with learned rounding
```
```mermaid
sequenceDiagram
    participant Client as Caller
    participant QAT as QATTrainer
    participant Loss as LMLogitsLoss
    participant Fused as LigerFusedLinearJSD<br/>(optional)
    Client->>QAT: Initialize with distill_config and use_liger_kernel
    QAT->>QAT: Convert model to distillation form
    alt use_liger_kernel enabled
        QAT->>QAT: Check lm_head compatibility
        QAT->>Fused: Create fused JSD loss
        QAT->>QAT: Set compute_loss_func to fused path
    else use_liger_kernel disabled
        QAT->>Loss: Create compute_kd_loss with masking
        QAT->>QAT: Set compute_loss_func to KD path
    end
    loop Training Step
        Client->>QAT: Run training step
        QAT->>QAT: Forward student and teacher
        QAT->>Loss: Compute per-token KL-div or JSD
        Loss-->>QAT: Unreduced per-token losses (B*S,)
        QAT->>QAT: Reduce and backprop
    end
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Actionable comments posted: 12
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_qat/create_ptq.py`:
- Around line 137-145: The CLI defaults for dataset arguments are non-portable
and can cause stale cache reuse; update the p.add_argument call for "--dataset"
to remove the hardcoded default and instead make the arg required (e.g.,
p.add_argument("--dataset", type=str, required=True, help=...)), and change the
"--dataset_cache_path" default from "dataset_cache" to an empty string (or None)
so cache reuse is explicit (e.g., p.add_argument("--dataset_cache_path",
type=str, default="", help=...)); adjust any downstream code that assumes a
non-empty cache path to handle empty/None appropriately before returning
p.parse_args().
In `@examples/llm_qat/launch.sh`:
- Around line 53-55: The current launch.sh silently collects any unrecognized
token into EXTRA_ARGS which gets forwarded to main.py (which uses
HfArgumentParser.parse_args_into_dataclasses() and will crash on unknown args);
update the script to stop appending arbitrary arguments in the default *) branch
of the case: either remove the EXTRA_ARGS mechanism entirely, or
validate/whitelist tokens before appending (e.g., check $1 against a set of
allowed flags) and otherwise emit a clear error and exit; reference EXTRA_ARGS,
the default *) case in launch.sh, and main.py /
HfArgumentParser.parse_args_into_dataclasses() when making the change.
In `@examples/llm_qat/utils.py`:
- Around line 254-255: The code incorrectly treats pad_token_id == 0 as missing
by using "pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id"; change
the logic to detect only None so a zero ID is preserved (e.g., check
tokenizer.pad_token_id is None or use a ternary that uses tokenizer.eos_token_id
only when pad_token_id is None). Update the assignment around pad_token and the
subsequent None-check to use explicit None checks referencing
tokenizer.pad_token_id and tokenizer.eos_token_id so valid 0 pad IDs are not
overridden.
- Around line 96-104: The cache key currently omits tokenizer identity causing
tokenization mismatches; modify _build_cache_path to accept a tokenizer
identifier (e.g., tokenizer.name_or_path) and include it in the hashed string,
and then thread that identifier through the caller _load_cached_dataset so it
passes tokenizer.name_or_path into _build_cache_path when constructing the cache
path; ensure the identifier is stable (name_or_path) and included in the same
f"{...}" input used to compute cache_key.
In `@modelopt/torch/distill/plugins/huggingface.py`:
- Around line 56-73: Ensure the fused LIGER JSD path is only enabled when there
is exactly one distillation loss of the expected LMLogitsLoss type: in
_setup_liger_fused_loss check that model._layers_to_loss contains a single entry
and that that loss object is the LMLogitsLoss (or has the specific attributes
_temperature and logits inputs) before setting use_liger_kernel=True and
compute_loss_func=_liger_loss_func; otherwise set use_liger_kernel=False and
compute_loss_func=compute_kd_loss. Apply the same exact precondition check in
the other setup block mentioned (the similar code around lines 91-124) so
_liger_loss_func is never selected when multiple or different loss types are
configured.
In `@modelopt/torch/opt/plugins/transformers.py`:
- Around line 417-431: The patched forward assignment in _fsdp_forward_redirect
must be undone even on exceptions: wrap the invocation of fsdp_module in a
try/finally so fsdp_module.forward is always reset to original_forward; perform
the assignment fsdp_module.forward = wrapped_forward, call fsdp_module(...)
inside try, and in finally restore fsdp_module.forward = original_forward. Also
replace the string sentinel argument ("_fsdp_redirect") with a dummy tensor
created on the module's device/dtype (e.g., a torch.zeros tensor derived from a
model parameter or buffer) so forward pre-hooks receive a properly-typed input;
keep the inner wrapped_forward returning fn() unchanged.
In `@modelopt/torch/quantization/config.py`:
- Around line 268-279: Add the new INT3_BLOCKWISE_WEIGHT_ONLY_CFG preset to the
string-based selector by adding an entry for "INT3_BLOCKWISE_WEIGHT_ONLY_CFG" in
the choices mapping (the same mapping that currently lists other presets) so
that the config can be selected by name; update the choices dict/switch where
presets are registered (reference symbol: choices) to include a key pointing to
INT3_BLOCKWISE_WEIGHT_ONLY_CFG so examples that validate against
modelopt.torch.quantization.config.choices can select it.
In `@modelopt/torch/quantization/conversion.py`:
- Around line 132-140: The current sequential checks can call
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module) without ensuring
module is first a StaticBlockScaleQuantizer; update the restoration logic in
conversion.py so that when state["_is_nvfp4_static_adaround_quantizer"] is true
you first ensure or convert the module to a StaticBlockScaleQuantizer (i.e., run
StaticBlockScaleQuantizer.from_tensor_quantizer(module) if needed) or explicitly
validate that state["_is_nvfp4_static_quantizer"] is also true and raise/handle
a clear error; specifically touch the block that references
StaticBlockScaleQuantizer, NVFP4StaticAdaRoundQuantizer, from_nvfp4_quantizer,
and set_from_modelopt_state to enforce the invariant or perform the conversion
before calling from_nvfp4_quantizer.
In `@modelopt/torch/quantization/nn/modules/tensor_quantizer.py`:
- Around line 1555-1581: The conversion method from_nvfp4_quantizer should
reject non-FP4/ non-NVFP4 block quantizers instead of blindly mutating any
StaticBlockScaleQuantizer; add a guard near the start of
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer that inspects the source
quantizer's format/block-type (e.g., a property like block_format / is_fp4 /
nvfp4 flag on tq) and raise/assert if it is not an FP4/NVFP4 config, and do the
same check in the other related conversion methods mentioned (the other
from_nvfp4_quantizer conversion blocks at the ranges you noted) so AdaRound
remains FP4-only unless you implement a separate INT-specific AdaRound path.
In `@modelopt/torch/quantization/plugins/transformers_trainer.py`:
- Around line 308-312: The trainer's _setup_adaround() registers the AdaRound
callback but never propagates AdaRoundTrainingArguments.temperature to the
quantizers, so non-default temperatures are ignored; update _setup_adaround() to
iterate the model's quantizers (e.g., instances of NVFP4StaticAdaRoundQuantizer)
and set their temperature property from self.args.ada_round_args.temperature (or
self.ada_rounding_args / AdaRoundTrainingArguments.temperature where stored)
before adding the _AdaRoundAuxCallback, ensuring each
NVFP4StaticAdaRoundQuantizer.temperature (or equivalent attribute) is assigned
the trainer's temperature value so AdaRound uses the configured temperature.
- Around line 676-695: The branch in _compute_mse incorrectly uses
quantizer._cast_ste for any StaticBlockScaleQuantizer that has no learnable
modes, which incorrectly drops block/tensor scales for calibrated
(non-pre-divided) weights; change the condition so _cast_ste is used only when
the quantizer is a pre-divided/smooth_lsq case (i.e., when quantizer._smooth_lsq
is True), and for other StaticBlockScaleQuantizer instances call
quantizer(weight) instead; preserve the existing reshape logic using
quantizer._block_reshape_size and then reshape back to orig_shape before
computing sq_err.
- Around line 650-672: The current registration loop (_weight_entries) includes
quantized weights even when they are frozen or not in the optimizer, causing
qerr to backprop through tensors with no grads; update the loop that builds
self._weight_entries (the block using weight_attr_names, quantizer_attr_names,
weight_quantizer and is_enabled) to only append weights that are
torch.nn.Parameter AND weight.requires_grad is True AND the weight's id is
present in the optimizer param mapping (pid_to_group); similarly ensure the
later population of self._param_group_idx and self._multiplier only iterates
over these filtered _weight_entries (so id(weight) exists in pid_to_group) to
avoid creating entries for optimizer-less/frozen weights — apply the same guard
to the analogous registration block elsewhere that performs the same work.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b794a3ad-578d-4f61-90b1-5c3c9601564c
📒 Files selected for processing (23)

- examples/llm_qat/README.md
- examples/llm_qat/create_ptq.py
- examples/llm_qat/launch.sh
- examples/llm_qat/main.py
- examples/llm_qat/utils.py
- modelopt/torch/distill/plugins/huggingface.py
- modelopt/torch/opt/plugins/huggingface.py
- modelopt/torch/opt/plugins/transformers.py
- modelopt/torch/quantization/config.py
- modelopt/torch/quantization/conversion.py
- modelopt/torch/quantization/mode.py
- modelopt/torch/quantization/model_calib.py
- modelopt/torch/quantization/model_quant.py
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
- modelopt/torch/quantization/plugins/transformers_trainer.py
- modelopt/torch/quantization/tensor_quant.py
- modelopt/torch/quantization/triton/fp4_kernel.py
- modelopt/torch/quantization/triton/fp4_kernel_hopper.py
- tests/gpu/torch/quantization/test_adaround_trainer.py
- tests/gpu/torch/quantization/test_liger_loss.py
- tests/gpu/torch/quantization/test_qerr_trainer.py
- tests/gpu/torch/quantization/test_quantize_cuda.py
- tests/unit/torch/opt/plugins/test_lr_config.py
💤 Files with no reviewable changes (1)
- modelopt/torch/opt/plugins/huggingface.py
In `examples/llm_qat/create_ptq.py`:

```python
p.add_argument(
    "--dataset",
    type=str,
    default="/home/scratch.akuriparambi_coreai/datasets/qat_blend_sft/blend_sft.jsonl",
)
p.add_argument("--eval_size", type=int, default=0)
p.add_argument("--train_size", type=int, default=0)
p.add_argument("--dataset_cache_path", type=str, default="dataset_cache")
return p.parse_args()
```
Use portable, non-stale defaults for the dataset arguments.
The current --dataset default only exists on one machine, and the fixed "dataset_cache" default will silently reuse tokenized data from previous runs with different inputs. Making --dataset required and leaving --dataset_cache_path empty avoids both footguns.
Suggested fix

```diff
 p.add_argument(
     "--dataset",
     type=str,
-    default="/home/scratch.akuriparambi_coreai/datasets/qat_blend_sft/blend_sft.jsonl",
+    required=True,
 )
 @@
-p.add_argument("--dataset_cache_path", type=str, default="dataset_cache")
+p.add_argument("--dataset_cache_path", type=str, default="")
```
In `examples/llm_qat/launch.sh`:

```diff
 *)
-    >&2 printf "Error: Invalid argument ${1#*=}\n"
-    exit 1
+    EXTRA_ARGS="${EXTRA_ARGS:-} $1"
     ;;
```
🧩 Analysis chain (collapsed): CodeRabbit ran a series of rg/grep/sed verification scripts against examples/llm_qat/main.py and examples/llm_qat/launch.sh to confirm how EXTRA_ARGS is forwarded and how HfArgumentParser is invoked; the script outputs are elided in this export.
Unrecognized arguments forwarded to main.py will cause a runtime crash.
The EXTRA_ARGS mechanism (line 54) silently collects unrecognized flags and appends them to the Python command (line 181), but main.py uses HfArgumentParser.parse_args_into_dataclasses() which calls argparse.parse_args() internally—not parse_known_args(). Any unrecognized argument will raise an error at runtime.
Either:

- Switch `main.py` to use `parse_known_args()` and handle the remainder, or
- Validate that `EXTRA_ARGS` only contains arguments recognized by `main.py`, or
- Remove the `EXTRA_ARGS` mechanism if it's not actively needed
In `examples/llm_qat/utils.py`:

```python
def _build_cache_path(
    dataset: str, dataset_cache_path: str, max_length: int, train_size: int, eval_size: int
) -> str:
    if dataset_cache_path:
        return dataset_cache_path
    cache_key = hashlib.sha1(
        f"{dataset}|{max_length}|{train_size}|{eval_size}".encode()
    ).hexdigest()[:12]
    return os.path.join(tempfile.gettempdir(), f"llm_qat_tokenized_{cache_key}")
```
Include tokenizer identity in the derived cache path.
The hash only uses dataset/length/split sizes. A second run with a different tokenizer but the same dataset will silently reload the first run's cached token IDs and labels.
Suggested fix

```diff
 def _build_cache_path(
-    dataset: str, dataset_cache_path: str, max_length: int, train_size: int, eval_size: int
+    dataset: str,
+    dataset_cache_path: str,
+    tokenizer_id: str,
+    max_length: int,
+    train_size: int,
+    eval_size: int,
 ) -> str:
     if dataset_cache_path:
         return dataset_cache_path
     cache_key = hashlib.sha1(
-        f"{dataset}|{max_length}|{train_size}|{eval_size}".encode()
+        f"{dataset}|{tokenizer_id}|{max_length}|{train_size}|{eval_size}".encode()
     ).hexdigest()[:12]
```

Then thread a stable tokenizer identifier through `_load_cached_dataset()`, e.g. `tokenizer.name_or_path`.
```python
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
if pad_token is None:
```
Don't treat `pad_token_id == 0` as missing.
0 is a valid token ID. Using `or` here falls through to `eos_token_id`, so some tokenizers will pad with the wrong token.
Suggested fix:

```diff
-pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
+pad_token = (
+    tokenizer.pad_token_id
+    if tokenizer.pad_token_id is not None
+    else tokenizer.eos_token_id
+)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_qat/utils.py` around lines 254 - 255, The code incorrectly
treats pad_token_id == 0 as missing by using "pad_token = tokenizer.pad_token_id
or tokenizer.eos_token_id"; change the logic to detect only None so a zero ID is
preserved (e.g., check tokenizer.pad_token_id is None or use a ternary that uses
tokenizer.eos_token_id only when pad_token_id is None). Update the assignment
around pad_token and the subsequent None-check to use explicit None checks
referencing tokenizer.pad_token_id and tokenizer.eos_token_id so valid 0 pad IDs
are not overridden.
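The pitfall is easy to reproduce with plain Python. This toy stub (not the real tokenizer class) shows how `or` mishandles a pad token ID of 0:

```python
class FakeTokenizer:
    """Toy stand-in: some real tokenizers legitimately use 0 as pad_token_id."""
    pad_token_id = 0
    eos_token_id = 2


tok = FakeTokenizer()

# Buggy: `or` treats 0 as falsy and falls through to eos_token_id.
buggy = tok.pad_token_id or tok.eos_token_id
assert buggy == 2  # padding would use the wrong token

# Fixed: only None means "missing".
fixed = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
assert fixed == 0
```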
```diff
 def _setup_liger_fused_loss(self):
     """Set up fused JSD for KD.
 
-    Args:
-        model: The model to compute loss for.
-        inputs: The inputs to the model.
+    No-op when called from ModelOptHFTrainer.__init__ (teacher not yet created).
+    Re-called from KDTrainer.__init__ after _convert_to_distillation_model().
     """
-    if not model.training:
-        _compute_loss_func = self.compute_loss_func
-        self.compute_loss_func = None
     model = self.accelerator.unwrap_model(self.model)
     if not hasattr(model, "_teacher_model"):
         return
     teacher = model._teacher_model
     if not hasattr(model, "lm_head") or not hasattr(teacher, "lm_head"):
         self.use_liger_kernel = False
         self.compute_loss_func = self.compute_kd_loss
         return
 
     loss_fn = next(iter(model._layers_to_loss.values()))
     self._liger_temperature = getattr(loss_fn, "_temperature", 1.0)
     self.compute_loss_func = self._liger_loss_func
```
Only enable fused KD for the single-LMLogitsLoss case.
This code reads just the first entry from model._layers_to_loss, and _liger_loss_func() always computes one fused JSD term. If a caller configures multiple distillation losses, or a different criterion type, use_liger_kernel=True silently changes the objective instead of reproducing compute_kd_loss(). Please validate that precondition here and fall back to compute_kd_loss() otherwise.
Also applies to: 91-124
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/distill/plugins/huggingface.py` around lines 56 - 73, Ensure
the fused LIGER JSD path is only enabled when there is exactly one distillation
loss of the expected LMLogitsLoss type: in _setup_liger_fused_loss check that
model._layers_to_loss contains a single entry and that that loss object is the
LMLogitsLoss (or has the specific attributes _temperature and logits inputs)
before setting use_liger_kernel=True and compute_loss_func=_liger_loss_func;
otherwise set use_liger_kernel=False and compute_loss_func=compute_kd_loss.
Apply the same exact precondition check in the other setup block mentioned (the
similar code around lines 91-124) so _liger_loss_func is never selected when
multiple or different loss types are configured.
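The requested precondition can be expressed as a small guard, sketched here with stand-in classes (`LMLogitsLoss` and `_layers_to_loss` are the names the comment refers to; the surrounding trainer is omitted):

```python
class LMLogitsLoss:
    """Stand-in for the logits-level KD criterion."""
    _temperature = 2.0


class OtherLoss:
    """Stand-in for any other distillation criterion."""


def can_use_fused_jsd(layers_to_loss):
    # Fused JSD reproduces compute_kd_loss only when exactly one
    # logits-level loss is configured; anything else must fall back.
    if len(layers_to_loss) != 1:
        return False
    (loss_fn,) = layers_to_loss.values()
    return isinstance(loss_fn, LMLogitsLoss)


assert can_use_fused_jsd({"lm_head": LMLogitsLoss()})
assert not can_use_fused_jsd({"lm_head": LMLogitsLoss(), "h.0": OtherLoss()})
assert not can_use_fused_jsd({"lm_head": OtherLoss()})
```

When the guard returns False, the setup code would keep `use_liger_kernel=False` and assign `compute_kd_loss` instead of `_liger_loss_func`.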
```diff
 if state.get("_is_nvfp4_static_quantizer") and not isinstance(
-    module, NVFP4StaticQuantizer
+    module, StaticBlockScaleQuantizer
 ):
-    NVFP4StaticQuantizer.from_tensor_quantizer(module)
+    StaticBlockScaleQuantizer.from_tensor_quantizer(module)
 if state.get("_is_nvfp4_static_adaround_quantizer") and not isinstance(
     module, NVFP4StaticAdaRoundQuantizer
 ):
     NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module)
 module.set_from_modelopt_state(quantizer_state_dict[name])
```
Add validation to ensure both flags are set together for AdaRound quantizers.
The two conversion checks are independent if statements and execute sequentially. The from_nvfp4_quantizer method requires its input to be a StaticBlockScaleQuantizer (enforced by assertion). If a saved state has _is_nvfp4_static_adaround_quantizer=True but _is_nvfp4_static_quantizer=False, the first conversion is skipped and the second fails on the assertion.
While normal operation ensures both flags are set together when creating an NVFP4StaticAdaRoundQuantizer, there is no explicit validation during state restoration. Consider adding a check to ensure saved states maintain this invariant, or restructure the restoration logic to handle this edge case gracefully (e.g., convert to StaticBlockScaleQuantizer first if needed before checking for AdaRound).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/conversion.py` around lines 132 - 140, The
current sequential checks can call
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module) without ensuring
module is first a StaticBlockScaleQuantizer; update the restoration logic in
conversion.py so that when state["_is_nvfp4_static_adaround_quantizer"] is true
you first ensure or convert the module to a StaticBlockScaleQuantizer (i.e., run
StaticBlockScaleQuantizer.from_tensor_quantizer(module) if needed) or explicitly
validate that state["_is_nvfp4_static_quantizer"] is also true and raise/handle
a clear error; specifically touch the block that references
StaticBlockScaleQuantizer, NVFP4StaticAdaRoundQuantizer, from_nvfp4_quantizer,
and set_from_modelopt_state to enforce the invariant or perform the conversion
before calling from_nvfp4_quantizer.
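One way to make the restoration robust, sketched with stand-in classes (the real quantizer classes live in modelopt; only the flag names come from the snippet above): convert to the block-scale base class first whenever the AdaRound flag is set.

```python
class TensorQuantizer:
    """Stand-in for the base quantizer."""


class StaticBlockScaleQuantizer(TensorQuantizer):
    @classmethod
    def from_tensor_quantizer(cls, tq):
        tq.__class__ = cls  # in-place class swap, as in the real conversion
        return tq


class NVFP4StaticAdaRoundQuantizer(StaticBlockScaleQuantizer):
    @classmethod
    def from_nvfp4_quantizer(cls, tq):
        assert isinstance(tq, StaticBlockScaleQuantizer)
        tq.__class__ = cls
        return tq


def restore(module, state):
    # AdaRound implies the block-scale base class, so convert to it first
    # even when a saved state (incorrectly) omits the base flag.
    needs_block = state.get("_is_nvfp4_static_quantizer") or state.get(
        "_is_nvfp4_static_adaround_quantizer"
    )
    if needs_block and not isinstance(module, StaticBlockScaleQuantizer):
        StaticBlockScaleQuantizer.from_tensor_quantizer(module)
    if state.get("_is_nvfp4_static_adaround_quantizer") and not isinstance(
        module, NVFP4StaticAdaRoundQuantizer
    ):
        NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module)
    return module


# A state with only the AdaRound flag no longer trips the assertion.
m = restore(TensorQuantizer(), {"_is_nvfp4_static_adaround_quantizer": True})
assert isinstance(m, NVFP4StaticAdaRoundQuantizer)
```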
```python
def from_nvfp4_quantizer(
    cls,
    tq: StaticBlockScaleQuantizer,
    weight_scaled: torch.Tensor | None = None,
) -> "NVFP4StaticAdaRoundQuantizer":
    """Convert an NVFP4StaticQuantizer to NVFP4StaticAdaRoundQuantizer in-place.

    Args:
        tq: The NVFP4StaticQuantizer to convert.
        weight_scaled: Pre-scaled weight tensor of shape ``[num_blocks, block_size]``.
            If provided, :meth:`enable_adaround` is called immediately.
    """
    assert isinstance(tq, StaticBlockScaleQuantizer), (
        f"Expected StaticBlockScaleQuantizer, got {type(tq)}"
    )

    if isinstance(tq, cls):
        if weight_scaled is not None:
            tq.enable_adaround(weight_scaled)
        return tq
    tq.__class__ = cls
    tq._is_nvfp4_static_adaround_quantizer = True
    tq._adaround_enabled = False
    tq.temperature = 1.0
    if weight_scaled is not None:
        tq.enable_adaround(weight_scaled)
    return tq
```
Reject non-FP4 quantizers in from_nvfp4_quantizer().
After this PR StaticBlockScaleQuantizer also represents INT block formats, but this conversion path still initializes AdaRound with fp4_cast_ste() / fp4_step_size() and _cast_ste() stays FP4-only. Converting an INT4/INT8 block quantizer here will learn rounding decisions on the wrong grid. Please guard this to the supported FP4 config, or add an int-specific AdaRound implementation.
Also applies to: 1595-1609, 1638-1648
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/nn/modules/tensor_quantizer.py` around lines 1555
- 1581, The conversion method from_nvfp4_quantizer should reject non-FP4/
non-NVFP4 block quantizers instead of blindly mutating any
StaticBlockScaleQuantizer; add a guard near the start of
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer that inspects the source
quantizer's format/block-type (e.g., a property like block_format / is_fp4 /
nvfp4 flag on tq) and raise/assert if it is not an FP4/NVFP4 config, and do the
same check in the other related conversion methods mentioned (the other
from_nvfp4_quantizer conversion blocks at the ranges you noted) so AdaRound
remains FP4-only unless you implement a separate INT-specific AdaRound path.
```python
def _setup_adaround(self):
    """Set up AdaRound: register aux callback and freeze parent weights."""
    self._adaround_pending_metrics = {}
    self.add_callback(_AdaRoundAuxCallback(trainer=self))
    self._freeze_adaround_weights()
```
Propagate adaround_args.temperature into the quantizers.
AdaRoundTrainingArguments.temperature is stored on the trainer, but _setup_adaround() never writes it to any NVFP4StaticAdaRoundQuantizer. Non-default values are ignored and AdaRound always runs at temperature = 1.0.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/plugins/transformers_trainer.py` around lines 308
- 312, The trainer's _setup_adaround() registers the AdaRound callback but never
propagates AdaRoundTrainingArguments.temperature to the quantizers, so
non-default temperatures are ignored; update _setup_adaround() to iterate the
model's quantizers (e.g., instances of NVFP4StaticAdaRoundQuantizer) and set
their temperature property from self.args.ada_round_args.temperature (or
self.ada_rounding_args / AdaRoundTrainingArguments.temperature where stored)
before adding the _AdaRoundAuxCallback, ensuring each
NVFP4StaticAdaRoundQuantizer.temperature (or equivalent attribute) is assigned
the trainer's temperature value so AdaRound uses the configured temperature.
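A sketch of the missing propagation step, using stub objects in place of the real trainer and quantizer (the `temperature` attribute name follows the review; the real classes come from modelopt):

```python
class AdaRoundQuantizer:
    """Stand-in for NVFP4StaticAdaRoundQuantizer: defaults to temperature 1.0."""

    def __init__(self):
        self.temperature = 1.0


def propagate_temperature(modules, temperature):
    # Push the trainer-level AdaRound temperature into every AdaRound
    # quantizer before the aux callback is registered, so non-default
    # values are no longer silently ignored.
    for module in modules:
        if isinstance(module, AdaRoundQuantizer):
            module.temperature = temperature


quantizers = [AdaRoundQuantizer(), AdaRoundQuantizer()]
propagate_temperature(quantizers, temperature=4.0)
assert all(q.temperature == 4.0 for q in quantizers)
```

In the real `_setup_adaround()`, the iteration would run over `model.named_modules()` and the value would come from the trainer's AdaRound arguments.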
```python
model = self._trainer.accelerator.unwrap_model(self._trainer.model)
self._weight_entries = []  # list of (weight_param, quantizer)
for _name, module in model.named_modules():
    for weight_name in weight_attr_names(module):
        wq_name = quantizer_attr_names(weight_name).weight_quantizer
        quantizer = getattr(module, wq_name, None)
        if quantizer is None or not quantizer.is_enabled:
            continue
        weight = getattr(module, weight_name, None)
        if not isinstance(weight, torch.nn.Parameter):
            continue
        self._weight_entries.append((weight, quantizer))

pid_to_group = {}
for group_idx, group in enumerate(self._trainer.optimizer.param_groups):
    for p in group["params"]:
        pid_to_group[id(p)] = group_idx
self._param_group_idx = {}
self._multiplier = {}
for weight, _q in self._weight_entries:
    self._param_group_idx[id(weight)] = pid_to_group.get(id(weight), 0)
    self._multiplier[id(weight)] = torch.zeros(1, device=weight.device)
```
QERR should ignore frozen weights.
This helper registers every quantized weight, even when it has requires_grad=False or is absent from the optimizer. In LoRA / trainable_params runs, the qerr_coeff > 0 path backprops through those tensors and later uses weight.grad even though it is None. Filter _weight_entries down to trainable optimizer-owned weights here.
Also applies to: 717-728
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/plugins/transformers_trainer.py` around lines 650
- 672, The current registration loop (_weight_entries) includes quantized
weights even when they are frozen or not in the optimizer, causing qerr to
backprop through tensors with no grads; update the loop that builds
self._weight_entries (the block using weight_attr_names, quantizer_attr_names,
weight_quantizer and is_enabled) to only append weights that are
torch.nn.Parameter AND weight.requires_grad is True AND the weight's id is
present in the optimizer param mapping (pid_to_group); similarly ensure the
later population of self._param_group_idx and self._multiplier only iterates
over these filtered _weight_entries (so id(weight) exists in pid_to_group) to
avoid creating entries for optimizer-less/frozen weights — apply the same guard
to the analogous registration block elsewhere that performs the same work.
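The suggested filter can be sketched with plain objects standing in for parameters (`requires_grad` and the optimizer's `pid_to_group` map are the two gates the review asks for):

```python
class Param:
    """Toy stand-in for torch.nn.Parameter."""

    def __init__(self, requires_grad):
        self.requires_grad = requires_grad


frozen = Param(requires_grad=False)    # e.g. a LoRA-frozen base weight
trainable = Param(requires_grad=True)  # owned by the optimizer
orphan = Param(requires_grad=True)     # trainable but not in any param group

pid_to_group = {id(trainable): 0}  # built from optimizer.param_groups

candidates = [(frozen, "q1"), (trainable, "q2"), (orphan, "q3")]

# Keep only weights that are trainable AND optimizer-owned, so the
# QERR path never backprops into tensors whose .grad stays None.
weight_entries = [
    (w, q) for w, q in candidates if w.requires_grad and id(w) in pid_to_group
]
assert [q for _, q in weight_entries] == ["q2"]
```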
```python
def _compute_mse(self, weight, quantizer):
    """Compute MSE between original and quantized weight."""
    if isinstance(quantizer, StaticBlockScaleQuantizer) and not (
        quantizer._lsq or quantizer._laq or quantizer._smooth_laq
    ):
        # smooth_lsq: weights are pre-divided, use raw cast
        orig_shape = weight.shape
        if hasattr(quantizer, "_block_reshape_size"):
            w = weight.reshape(quantizer._block_reshape_size)
        else:
            w = weight
        q_weight = quantizer._cast_ste(w)
        q_weight = q_weight.reshape(orig_shape)
    else:
        q_weight = quantizer(weight)
    sq_err = (q_weight.detach() - weight) ** 2
    if self._trainer.qerr_args.qerr_reduction == "sum":
        return sq_err.sum()
    return sq_err.mean()
```
The _cast_ste shortcut is wrong for plain static-block quantizers.
This branch also catches a calibrated StaticBlockScaleQuantizer with no learnable mode enabled. In that case the stored weight is still in real scale, so _cast_ste(w) drops the block/tensor scales and the QERR metric/update targets the wrong grid. Keep the shortcut for pre-divided smooth_lsq weights only; use quantizer(weight) for the other static-block cases.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/plugins/transformers_trainer.py` around lines 676
- 695, The branch in _compute_mse incorrectly uses quantizer._cast_ste for any
StaticBlockScaleQuantizer that has no learnable modes, which incorrectly drops
block/tensor scales for calibrated (non-pre-divided) weights; change the
condition so _cast_ste is used only when the quantizer is a
pre-divided/smooth_lsq case (i.e., when quantizer._smooth_lsq is True), and for
other StaticBlockScaleQuantizer instances call quantizer(weight) instead;
preserve the existing reshape logic using quantizer._block_reshape_size and then
reshape back to orig_shape before computing sq_err.
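The corrected branching can be shown with a dummy quantizer; only the `_smooth_lsq` flag and `_cast_ste` names come from the snippet above, the rest is illustrative:

```python
class DummyQuantizer:
    def __init__(self, smooth_lsq=False):
        self._smooth_lsq = smooth_lsq
        self.raw_cast_used = False
        self.full_quant_used = False

    def _cast_ste(self, w):
        # Raw cast: valid only when weights were pre-divided by the scales.
        self.raw_cast_used = True
        return w

    def __call__(self, w):
        # Full path: applies block/tensor scales before casting.
        self.full_quant_used = True
        return w


def quantize_for_mse(weight, q):
    # Only the smooth_lsq case stores pre-divided weights; every other
    # static-block quantizer must go through the full __call__ path so
    # the block/tensor scales are applied.
    return q._cast_ste(weight) if q._smooth_lsq else q(weight)


q_plain, q_smooth = DummyQuantizer(), DummyQuantizer(smooth_lsq=True)
quantize_for_mse([1.0], q_plain)
quantize_for_mse([1.0], q_smooth)
assert q_plain.full_quant_used and not q_plain.raw_cast_used
assert q_smooth.raw_cast_used and not q_smooth.full_quant_used
```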
cjluo-nv
left a comment
Summary: This PR introduces five new quantization algorithms (SmoothLSQ, LSQ, LAQ, SmoothLAQ, AdaRound) with learnable-scale quantizer support, refactors the base quantizer class from NVFP4StaticQuantizer to StaticBlockScaleQuantizer, adds training infrastructure (AdaRound dist_loss callback, quantization error regularization, Liger kernel fusion, per-parameter LR config), and extends the QAT example with local dataset support and caching. This is a very large PR (~3500 lines of diff across 23 files).
Issues Found:

- [Correctness] Hardcoded internal path in `create_ptq.py:138`: the default `--dataset` argument is `/home/scratch.akuriparambi_coreai/datasets/qat_blend_sft/blend_sft.jsonl` — an internal filesystem path that won't exist outside NVIDIA. This should be changed to require the argument or use a publicly available dataset.
- [Correctness] Tokenization format change in `utils.py:_make_tokenize_fn`: the new tokenizer prepends role labels (`f"{role}: {content}\n"`, e.g. `"User: hello\n"`) whereas the original `get_daring_anteater` tokenized only the conversation value (`conversation["value"] + "\n"`). This silently changes the tokenization semantics for the Daring-Anteater dataset path, which flows through `_normalize_to_messages` → `_make_tokenize_fn`. The role prefix tokens weren't in the original training data and will affect loss masking, since only `"Assistant"`-role tokens get labels.
- [Correctness] `_load_cached_dataset` tokenizes then splits vs. the original split-then-tokenize: the original `get_daring_anteater` shuffled, selected, tokenized (via `.map()`), then did `train_test_split`. The new `_load_cached_dataset` shuffles, selects, splits, then tokenizes per split. While functionally similar, the split happens on raw data vs. tokenized data, meaning the seed-42 split will produce different train/test assignments than before — a silent regression for anyone comparing against old results.
- [Correctness] `launch.sh` silently swallows unknown args: lines 52-53 change the behavior from erroring on unknown args to collecting them in `EXTRA_ARGS`. This is convenient for pass-through, but it means typos in argument names (e.g. `--ouput_dir`) will silently be passed to `main.py` rather than caught early. At minimum, a warning would be appropriate.
- [Correctness] `model_quant.py:quantize()` skip-quantize path: the new logic at lines 231-236 skips `apply_mode` for already-quantized models and only calls `set_quantizer_by_cfg`. This means that if a user calls `quantize()` twice with a different `quant_cfg`, the module structure won't be updated (e.g., no new quantizer modules added); only existing quantizer attributes get set. This could silently produce wrong results if the second config targets quantizers that don't exist yet. The behavior change should at minimum be documented.
- [Correctness] `_apply_gradient_checkpointing_defaults` silently overrides the user's setting: in `transformers.py:207-215`, `use_reentrant=True` is forcibly overridden to `False` with only a warning. The original code in `main.py` set `use_reentrant=True` (now removed). While `use_reentrant=False` is generally better, forcibly overriding a user's explicit setting is surprising behavior for a base trainer class.
- [Correctness] QATTrainer rejects non-distributed mode: lines 258-262 raise `ValueError` for `ParallelMode.NOT_DISTRIBUTED`, but the test fixtures monkeypatch to `NOT_PARALLEL`. This means the tests bypass a check that would fire in real single-GPU usage. If single-GPU is truly unsupported, this needs clearer documentation; if it should work, the check needs refinement.
- [Duplicated Code] `is_static_block_scale` check repeated 3+ times: the pattern checking `module.is_static_block_quant and module._block_sizes is not None and ((module._num_bits == (2, 1) and ...) or isinstance(module._num_bits, int))` appears in `mse_calibrate`, `local_hessian_calibrate`, and `_convert_to_static_block_quantizers`. Extract this into a helper like `_is_eligible_for_static_block_scale(quantizer)`.
- [Duplicated Code] Config class boilerplate: `SmoothLSQConfig`, `LSQConfig`, `LAQConfig`, and `SmoothLAQConfig` in `config.py` are nearly identical — each has only a `method` literal and an identical `scale_algorithm` field with the same description. Consider a base class or factory to reduce the ~100 lines of duplication.
- [Readability] `StaticBlockScaleQuantizer._fake_quantize` is very long: the `_fake_quantize` method handles 6+ code paths (smooth_lsq, lsq, laq, smooth_laq, FP4 static, INT static, fallback). This 50+ line method with deeply nested conditions would benefit from being split into named helpers (e.g. `_fake_quantize_learnable`, `_fake_quantize_static`).
- [Readability] Class-level mutable state as class attributes: `StaticBlockScaleQuantizer` defines `_smooth_lsq`, `_lsq`, `_laq`, and `_smooth_laq` as class-level booleans (lines 1293-1296 in the diff). These are instance-level state mutated by the `enable_*` methods. While Python handles this correctly (instance assignment shadows the class attribute), it's a confusing pattern — these should be set in `from_tensor_quantizer` or `__init__`.
- [Tests] Missing test for tokenization change: the tokenization format change in `_make_tokenize_fn` (prepending role labels) has no test coverage. The existing tests only cover the quantization algorithms and trainer integration.
- [Correctness] `torch.float32` default dtype change in `main.py:202`: changed from `torch_dtype=torch.bfloat16` to `torch.float32` with a comment about mixed precision. This doubles the memory requirement for model loading. The justification comment is thin — mixed-precision training typically handles bf16 casting itself, but loading in fp32 means the model sits in fp32 until the first forward pass. This could cause OOM on memory-constrained setups. It should at least be configurable.
- [Correctness] `_QuantErrorAuxCallback._compute_mse` calls `mse.backward()` outside autocast: the QERR callback (line 735 in `transformers_trainer.py`) calls `mse.backward()` on a manually computed MSE. If mixed precision is active, this backward pass happens outside the trainer's autocast context, which could lead to dtype mismatches or suboptimal gradient precision.
Suggestions:

- The PR is very large (23 files, ~3500 lines). Consider splitting it into: (1) the StaticBlockScaleQuantizer refactor, (2) learnable-scale algorithms, (3) AdaRound, (4) training infrastructure (Liger, LR config, QERR), (5) example/dataset improvements.
- Add a migration note for the `NVFP4StaticQuantizer` → `StaticBlockScaleQuantizer` rename, even though the alias is preserved.
- The `_dataset_cache` module-level dict in `utils.py` is a global mutable singleton that persists across calls — consider documenting this clearly or using a more explicit caching mechanism.
Overall Assessment: The algorithmic work (learnable-scale quantizers, AdaRound) appears well-designed with good test coverage for the core quantization paths. However, the PR bundles too many orthogonal changes (training infra, dataset handling, Liger integration, gradient checkpointing policy changes) making it hard to review safely. The tokenization format change and the quantize() skip-path are the highest-risk correctness concerns. The hardcoded internal path must be fixed before merge.