diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 2a29f644e..96f75f8a8 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -319,15 +319,96 @@ trainer.save_state() trainer.save_model("") ``` +## DFlash: Block Diffusion for Flash Speculative Decoding + +DFlash ([arXiv:2602.06036](https://arxiv.org/abs/2602.06036)) is a parallel speculative decoding method that predicts multiple tokens simultaneously using block diffusion. Unlike autoregressive methods (EAGLE, Medusa) that draft one token at a time, DFlash predicts an entire block of tokens in parallel, then iteratively denoises them. + +### Architecture + +DFlash uses three key mechanisms: + +- **Feature Fusion**: Multi-layer hidden states from the target model are projected via a fully-connected layer and RMSNorm to create context features +- **KV Injection**: Context features are injected as K/V in every draft decoder layer, while Q comes from the noise embeddings. QK-Norm (RMSNorm on Q and K before RoPE) stabilizes attention +- **Parallel Drafting**: Within each block of size B, unknown positions use a `mask_token_id` token. Only block-start positions get the real token. The attention mask allows noise tokens to attend to all context tokens from previous blocks, plus causally within the same block + +### Training + +```bash +./launch_train.sh --model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --data input_conversations/train.jsonl \ + --num_epochs $NUM_EPOCH \ + --mode dflash \ + --dflash_block_size 16 \ + --dflash_num_layers 5 +``` + +Key arguments: + +| Flag | Default | Description | +|------|---------|-------------| +| `--mode dflash` | - | Enable DFlash mode | +| `--dflash_block_size` | 16 | Block size for parallel prediction | +| `--dflash_num_layers` | 5 | Number of decoder layers in draft module | +| `--dflash_config` | None | Path to JSON config for custom architecture | +| `--dflash_mask_token_id` | auto | Mask token ID (auto-detected from model) | +| `--dflash_disable_torch_compile` | False | Disable torch.compile | +| `--dflash_use_logit_distillation` | False | Use KD from target model logits instead of hard CE | + +### mask_token_id + +The `mask_token_id` is critical for DFlash training and inference. It must be consistent between training and deployment. Auto-detection logic: + +| Model Family | mask_token_id | Source | +|-------------|---------------|--------| +| Qwen3.5 | 248070 | Built-in `[MASK]` token | +| Qwen3 (8B) | 151643 | `eos_token_id` | +| Llama 3 | 128002 | `reserved_special_token_0` | +| Others | `pad_token_id` | Fallback | + +Override with `--dflash_mask_token_id ` if auto-detection is incorrect. + +### Configuring Draft Model + +Similar to EAGLE, provide a JSON config to customize the draft architecture: + +```json +{ + "num_hidden_layers": 5, + "rms_norm_eps": 1e-6 +} +``` + +Model dimensions (hidden_size, num_attention_heads, etc.) are automatically inherited from the base model. + +### Current Status (WIP) + +| Feature | Status | +|---------|--------| +| Architecture (Feature Fusion, KV Injection, Parallel Drafting) | Working | +| Online training with HF Trainer | Working | +| Inference / AR validation (`pseudo_speculative_generate`) | Working | +| z-lab checkpoint loading and inference (AR 7-9) | Working | +| Logit distillation option | Working | +| Response-only loss masking | Working | +| DDP training | Working (with `find_unused_parameters=True`) | + +**Known gap**: Training with ModelOpt achieves ~35% per-token accuracy (matching SpecForge's ~30%), but acceptance rate (AR) is lower than SpecForge-trained checkpoints (1.15 vs 1.95). Investigation shows the **data pipeline** differs significantly: + +- SpecForge uses its own tokenizer template with system prompt and response-only loss mask +- ModelOpt's `LanguageDataCollator` uses `apply_chat_template` with different formatting + +Aligning the data pipeline is the next step to close the AR gap. + ## Support Matrix -| Model | Medusa | EAGLE1/2 | EAGLE3 | -| :---: | :---: | :---: | :---: | -| LLAMA 2 | ✅ | ✅ | ✅ | -| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | -| Mistral | ✅ | ✅ | ✅ | -| Phi 3 | ✅ | ✅ | ✅ | -| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | +| Model | Medusa | EAGLE1/2 | EAGLE3 | DFlash | +| :---: | :---: | :---: | :---: | :---: | +| LLAMA 2 | ✅ | ✅ | ✅ | ✅ | +| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | ✅ | +| Mistral | ✅ | ✅ | ✅ | ✅ | +| Phi 3 | ✅ | ✅ | ✅ | ✅ | +| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | ✅ | ## Speculation Module Checkpoints diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 1bc7df981..6f1ba87bb 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -139,6 +139,7 @@ def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, train_len=None, + answer_only_loss=False, ) -> dict: if data_args.offline_data_path is None: train_dataset = ShardedDataset("json", data_files=data_args.data_path) @@ -148,6 +149,7 @@ def make_eagle_supervised_data_module( tokenizer=tokenizer, train_len=train_len, return_labels=True, + answer_only_loss=answer_only_loss, ) else: data_collator = VisionLanguageDataCollator( @@ -203,6 +205,12 @@ def on_log(self, args, state, control, **kwargs): if not hasattr(state, "training_accs") or len(state.training_accs) == 0: return control average_acc = np.mean(state.training_accs, axis=0) + # Always print accuracy to console + try: + acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten()) + print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]") + except Exception: + print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}") if self.estimate_ar: # Calculate mean training AR since last log # NOTE: This is only an estimate of the real AR. @@ -235,23 +243,31 @@ def on_log(self, args, state, control, **kwargs): return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically, if available.""" + """Run AR validation periodically, if available. + + Only runs on rank 0 to avoid DDP deadlock — other ranks skip and + synchronize via barrier. + """ if self.ar_validate_steps <= 0: return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: - print_rank_0("Running AR validation...") - try: - ars = validate_ar( - model=kwargs["model"], - tokenizer=kwargs["processing_class"], - ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], - device=kwargs["model"].device, - ) - print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") - if wandb and is_master(): - wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) - except Exception: - print_rank_0("AR validation not available.") + if is_master(): + print_rank_0("Running AR validation...") + try: + ars = validate_ar( + model=kwargs["model"], + tokenizer=kwargs["processing_class"], + ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], + device=kwargs["model"].device, + ) + print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") + if wandb: + wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) + except Exception: + print_rank_0("AR validation not available.") + # Barrier to synchronize all ranks after validation + if torch.distributed.is_initialized(): + torch.distributed.barrier() return control diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c15b97bda..7beb56742 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -134,6 +134,22 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi FSDP="${1#*=}" ;; + --dflash_block_size*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_BLOCK_SIZE="${1#*=}" + ;; + --dflash_num_layers*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_NUM_LAYERS="${1#*=}" + ;; + --dflash_config*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_CONFIG="${1#*=}" + ;; + --dflash_mask_token_id*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_MASK_TOKEN_ID="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -195,8 +211,20 @@ if [[ "$MODE" == "eagle3" ]]; then else SPECULATIVE_ARGS="" fi +elif [[ "$MODE" == "dflash" ]]; then + DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} + DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} + SPECULATIVE_ARGS="--dflash_block_size $DFLASH_BLOCK_SIZE --dflash_num_layers $DFLASH_NUM_LAYERS --dflash_disable_torch_compile" + if [[ -n "$DFLASH_CONFIG" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_config $DFLASH_CONFIG" + fi + if [[ -n "$DFLASH_MASK_TOKEN_ID" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID" + fi + # DFlash uses DDP instead of FSDP + FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" else - echo "Only eagle3 supported for now!" + echo "Unsupported mode: $MODE. Supported: eagle3, dflash" exit 1 fi @@ -218,12 +246,14 @@ else VLM_ARGS="" fi -if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then - #Use FSDP2 when multi GPU available - FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" -else - #Otherwise, single GPU training - FSDP_ARGS="" +if [[ "$MODE" != "dflash" ]]; then + if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then + #Use FSDP2 when multi GPU available + FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" + else + #Otherwise, single GPU training + FSDP_ARGS="" + fi fi diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3369d399c..650ff4433 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -104,7 +104,7 @@ class TrainingArguments(transformers.TrainingArguments): ) dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) - mode: Literal["eagle3", "medusa"] = "eagle3" + mode: Literal["eagle3", "medusa", "dflash"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR during training for logging."} ) @@ -144,6 +144,32 @@ class EagleArguments: ) +@dataclass +class DFlashArguments: + dflash_block_size: int = field( + default=16, metadata={"help": "Block size for DFlash parallel prediction."} + ) + dflash_num_layers: int = field( + default=5, metadata={"help": "Number of decoder layers in the DFlash draft module."} + ) + dflash_config: str | None = field(default=None, metadata={"help": "Path to dflash_config.json"}) + dflash_disable_torch_compile: bool = field( + default=False, + metadata={"help": "Disable torch.compile on DFlash forward/loss methods."}, + ) + dflash_mask_token_id: int | None = field( + default=None, + metadata={"help": "Mask token ID for DFlash. If not set, auto-detected from model."}, + ) + dflash_use_logit_distillation: bool = field( + default=False, + metadata={ + "help": "Use logit distillation (KD from target model) instead of hard CE. " + "Enables training with data not synthesized by the target model." + }, + ) + + def train(): parser = transformers.HfArgumentParser( ( @@ -152,9 +178,10 @@ def train(): TrainingArguments, MedusaArguments, EagleArguments, + DFlashArguments, ) ) - model_args, data_args, training_args, medusa_args, eagle_args = ( + model_args, data_args, training_args, medusa_args, eagle_args, dflash_args = ( parser.parse_args_into_dataclasses() ) if not data_args.data_path and not data_args.offline_data_path: @@ -236,13 +263,32 @@ def train(): ) model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") + elif training_args.mode == "dflash": + custom_config = ( + json.load(open(dflash_args.dflash_config)) if dflash_args.dflash_config else {} + ) + custom_config.setdefault("num_hidden_layers", dflash_args.dflash_num_layers) + if dflash_args.dflash_mask_token_id is not None: + custom_config["mask_token_id"] = dflash_args.dflash_mask_token_id + + config = { + "dflash_block_size": dflash_args.dflash_block_size, + "dflash_use_torch_compile": not dflash_args.dflash_disable_torch_compile, + "dflash_self_logit_distillation": dflash_args.dflash_use_logit_distillation, + "dflash_architecture_config": custom_config, + } + + mtsp.convert(model, [("dflash", config)]) else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") - if training_args.mode == "eagle3": + if training_args.mode in ("eagle3", "dflash"): data_module = make_eagle_supervised_data_module( - tokenizer, data_args, train_len=training_args.training_seq_len + tokenizer, + data_args, + train_len=training_args.training_seq_len, + answer_only_loss=(training_args.mode == "dflash"), ) trainer = EagleTrainerWithAccLog( diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 69491c659..59aa98db4 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -46,6 +46,54 @@ } +def _get_dflash_default_config(): + from .dflash.default_config import default_dflash_config + + return default_dflash_config + + +DFLASH_DEFAULT_CFG = { + "algorithm": "dflash", + "config": { + "dflash_architecture_config": {}, # merged with default at convert time + }, +} + + +class DFlashConfig(ModeloptBaseConfig): + """DFlash config for block-wise parallel speculative decoding.""" + + dflash_block_size: int = ModeloptField( + default=16, + description="Block size for parallel prediction. Draft predicts this many tokens per block.", + ) + + dflash_freeze_base_model: bool = ModeloptField( + default=True, description="Whether to freeze base model during DFlash module training." + ) + + dflash_self_logit_distillation: bool = ModeloptField( + default=True, description="Whether to use logit distillation from base model." + ) + + dflash_loss_decay_factor: float = ModeloptField( + default=0.9, description="Decay factor for per-block loss weighting." + ) + + dflash_report_acc: bool = ModeloptField( + default=True, description="Whether to report eval accuracy." + ) + + dflash_architecture_config: dict = ModeloptField( + default={}, description="Config for the DFlash draft module architecture." + ) + + dflash_use_torch_compile: bool = ModeloptField( + default=True, + description="Whether to use torch.compile on DFlash forward/loss methods.", + ) + + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/__init__.py b/modelopt/torch/speculative/dflash/__init__.py new file mode 100644 index 000000000..912b8d47a --- /dev/null +++ b/modelopt/torch/speculative/dflash/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash Optimization Method.""" + +from .conversion import * +from .default_config import * +from .dflash_model import * diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py new file mode 100644 index 000000000..943be90ca --- /dev/null +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash conversion/restore utilities.""" + +from torch import nn + +from modelopt.torch.opt.conversion import ModelLikeModule +from modelopt.torch.opt.dynamic import _DMRegistryCls +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict + +from ..config import DFlashConfig + +DFlashDMRegistry = _DMRegistryCls(prefix="DFlash") # global instance for the registry + + +def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType: + """Convert the model to a DFlash model as per `config`.""" + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + original_cls = type(model) + if original_cls not in DFlashDMRegistry: + for cls in DFlashDMRegistry._registry: + if issubclass(original_cls, cls): + DFlashDMRegistry.register({original_cls: "base_model_class"})(DFlashDMRegistry[cls]) + break + + # merge custom config with default config (lazy import to avoid circular) + from .default_config import default_dflash_config + + custom_config = config.dflash_architecture_config + config.dflash_architecture_config = {**default_dflash_config, **custom_config} + + dflash_model = DFlashDMRegistry.convert(model) + dflash_model.modify(config) + + metadata = {} + return dflash_model, metadata + + +def restore_dflash_model( + model: nn.Module, config: DFlashConfig, metadata: MetadataDict +) -> nn.Module: + """Function for restoring a previously converted model to a DFlash model.""" + assert not metadata, "No metadata expected!" + return convert_to_dflash_model(model, config)[0] diff --git a/modelopt/torch/speculative/dflash/default_config.py b/modelopt/torch/speculative/dflash/default_config.py new file mode 100644 index 000000000..5536e0d4d --- /dev/null +++ b/modelopt/torch/speculative/dflash/default_config.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default DFlash architecture config. + +Model-specific settings (hidden_size, num_attention_heads, rope_*, etc.) +are inherited from the base model in HFDFlashModel.modify(). Only +DFlash-specific defaults are set here. +""" + +default_dflash_config = { + "num_hidden_layers": 5, + "rms_norm_eps": 1e-06, + "attention_bias": False, + "attention_dropout": 0.0, +} diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py new file mode 100644 index 000000000..e44b17b50 --- /dev/null +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash model to support block-wise parallel speculative decoding.""" + +from modelopt.torch.opt.dynamic import DynamicModule + + +class DFlashModel(DynamicModule): + """Base DFlash Model.""" + + def _setup(self): + """Register temporary attributes for the DFlash module.""" + self._register_temp_attribute("dflash_module", None) + + def modify(self, config): + """Base DFlash Model modify function. Child class should implement the details.""" + self.dflash_block_size = config.dflash_block_size + self.dflash_freeze_base_model = config.dflash_freeze_base_model + self.dflash_loss_decay_factor = config.dflash_loss_decay_factor + self.dflash_self_logit_distillation = config.dflash_self_logit_distillation + self.dflash_report_acc = config.dflash_report_acc + self.dflash_use_torch_compile = config.dflash_use_torch_compile diff --git a/modelopt/torch/speculative/mode.py b/modelopt/torch/speculative/mode.py index 866449e15..ae965354a 100644 --- a/modelopt/torch/speculative/mode.py +++ b/modelopt/torch/speculative/mode.py @@ -23,7 +23,8 @@ _ModeRegistryCls, ) -from .config import EagleConfig, MedusaConfig +from .config import DFlashConfig, EagleConfig, MedusaConfig +from .dflash.conversion import convert_to_dflash_model, restore_dflash_model from .eagle.conversion import convert_to_eagle_model, restore_eagle_model from .medusa.conversion import convert_to_medusa_model, restore_medusa_model @@ -58,6 +59,34 @@ def restore(self) -> RestoreEntrypoint: return restore_medusa_model +@SpeculativeDecodingModeRegistry.register_mode +class DFlashModeDescriptor(ModeDescriptor): + """Class to describe the ``"dflash"`` mode. + + The properties of this mode can be inspected via the source code. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "dflash" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return DFlashConfig + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_dflash_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_dflash_model + + @SpeculativeDecodingModeRegistry.register_mode class EagleModeDescriptor(ModeDescriptor): """Class to describe the ``"eagle"`` mode. diff --git a/modelopt/torch/speculative/plugins/__init__.py b/modelopt/torch/speculative/plugins/__init__.py index 5e3f4bff2..d59aed37d 100644 --- a/modelopt/torch/speculative/plugins/__init__.py +++ b/modelopt/torch/speculative/plugins/__init__.py @@ -31,3 +31,6 @@ with import_plugin("transformers"): from .transformers import * + +with import_plugin("hf_dflash"): + from .hf_dflash import * diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py new file mode 100644 index 000000000..ffa2ed354 --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -0,0 +1,707 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash speculative decoding plugin for HuggingFace models. + +Matches the reference SpecForge implementation (github.com/sgl-project/SpecForge PR #415). + +Architecture: +- Feature Fusion: multi-layer target hidden states → FC + RMSNorm +- KV Injection: fused features as K/V in every draft layer with QK-norm +- Parallel Drafting: mask_token_id for unknown positions, causal within blocks +- Loss: hard CE on input_ids[i] (position i predicts token i) + +Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +""" + +import importlib + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.utils import ModelOutput + +from ..dflash.conversion import DFlashDMRegistry +from ..dflash.dflash_model import DFlashModel + + +def _resolve_model_components(model_type): + """Resolve MLP, RMSNorm, RotaryEmbedding from the base model's transformers module. + + Falls back to Llama components if the model type is unknown. + """ + fallback = "llama" + model_type = model_type or fallback + try: + mod = importlib.import_module(f"transformers.models.{model_type}.modeling_{model_type}") + except (ImportError, ModuleNotFoundError): + mod = importlib.import_module(f"transformers.models.{fallback}.modeling_{fallback}") + model_type = fallback + + prefix = model_type.capitalize() + # Handle multi-word model types (e.g., "qwen3" -> "Qwen3") + for attr in dir(mod): + if attr.lower() == f"{model_type}mlp": + prefix = attr.replace("MLP", "") + break + + mlp_cls = getattr(mod, f"{prefix}MLP", None) + norm_cls = getattr(mod, f"{prefix}RMSNorm", None) + rotary_cls = getattr(mod, f"{prefix}RotaryEmbedding", None) + rotate_half_fn = getattr(mod, "rotate_half", None) + + # Fallback to Llama if any component is missing + if not all([mlp_cls, norm_cls, rotary_cls, rotate_half_fn]): + from transformers.models.llama.modeling_llama import ( + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + ) + from transformers.models.llama.modeling_llama import rotate_half as _rotate_half + + mlp_cls = mlp_cls or LlamaMLP + norm_cls = norm_cls or LlamaRMSNorm + rotary_cls = rotary_cls or LlamaRotaryEmbedding + rotate_half_fn = rotate_half_fn or _rotate_half + + return mlp_cls, norm_cls, rotary_cls, rotate_half_fn + + +# Default to Llama components; overridden per-model during convert() +_MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components("llama") + +__all__ = ["HFDFlashModel"] + + +def build_target_layer_ids(num_target_layers, num_draft_layers): + """Select layers uniformly from the target model for feature extraction.""" + if num_draft_layers == 1: + return [num_target_layers // 2] + start = 1 + end = num_target_layers - 3 + span = end - start + return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] + + +def apply_rotary_pos_emb(q, k, cos, sin): + """Apply RoPE. Q uses last q_len positions, K uses all positions.""" + cos = cos.unsqueeze(1) # [B, 1, seq, dim] + sin = sin.unsqueeze(1) + q_len = q.size(2) + q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :]) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +class DFlashAttention(nn.Module): + """Attention with KV injection, matching SpecForge Qwen3DFlashAttention.""" + + def __init__(self, config, layer_idx): + """Initialize DFlash attention with KV injection projections and QK-norm.""" + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + + # QK norm (matches Qwen3DFlashAttention) + self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward with KV injection: Q from noise, K/V from context+noise.""" + bsz, q_len, _ = hidden_states.shape + ctx_len = target_hidden.shape[1] + + # Q from noise only, with QK-norm + q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + + # K from context + noise, with QK-norm + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view( + bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim + ) + k = self.k_norm(k).transpose(1, 2) + + # V from context + noise (no norm) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + v = ( + torch.cat([v_ctx, v_noise], dim=1) + .view(bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + ) + + # RoPE: applied to full 2L positions, Q gets last q_len, K gets all + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + # GQA expand + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, is_causal=False, scale=self.scaling + ) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1) + return self.o_proj(attn_output) + + +class DFlashDecoderLayer(nn.Module): + """Draft decoder layer with KV injection.""" + + def __init__(self, config, layer_idx): + """Initialize decoder layer with attention, MLP, and layer norms.""" + super().__init__() + self.self_attn = DFlashAttention(config, layer_idx) + self.mlp = _MLP_CLS(config) + self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward pass with residual connections.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states, target_hidden, position_embeddings, attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DFlashModule(nn.Module): + """DFlash draft module matching SpecForge DFlashDraftModel.""" + + def __init__(self, config): + """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings.""" + super().__init__() + self.config = config + self.block_size = config.block_size + + # Feature fusion + num_fused_layers = len(config.target_layer_ids) + self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) + self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + # Decoder layers + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = _ROTARY_CLS(config=config) + + def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): + """Forward matching SpecForge DFlashDraftModel.forward.""" + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask) + + return self.norm(hidden_states) + + +def create_dflash_attention_mask(seq_len, block_size, device, dtype): + """Create [L, 2L] attention mask matching SpecForge. + + Context (cols 0..L-1): Block B sees blocks 0..B-1 (strictly previous). + Noise (cols L..2L-1): causal within same block only. + """ + indices = torch.arange(seq_len, device=device) + block_ids = indices // block_size + + q_block_ids = block_ids.unsqueeze(1) # [L, 1] + k_block_ids = block_ids.unsqueeze(0) # [1, L] + + ctx_mask = k_block_ids < q_block_ids + same_block = q_block_ids == k_block_ids + causal = indices.unsqueeze(0) >= indices.unsqueeze(1) # matching SpecForge: j >= i + noise_mask = same_block & causal + + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) + + full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=dtype) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(dtype).min) + + return full_mask.unsqueeze(0).unsqueeze(0) # [1, 1, L, 2L] + + +def create_dflash_loss_mask(seq_len, block_size, device): + """Create loss mask: exclude Block 0 and block starts.""" + positions = torch.arange(seq_len, device=device) + block_ids = positions // block_size + is_block_0 = block_ids == 0 + is_block_start = (positions % block_size) == 0 + return (~is_block_0 & ~is_block_start).float() + + +@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +class HFDFlashModel(DFlashModel): + """DFlash Model matching SpecForge OnlineDFlashModel.""" + + @property + def _base_model(self): + return self.get_submodule(self.base_model_path) + + @property + def _base_model_embeddings(self): + return self.get_submodule(self.base_model_embeddings_path) + + @property + def _base_model_lm_head(self): + return self.get_submodule(self.base_model_lm_head_path) + + @property + def _base_llm_config(self): + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) + + @staticmethod + def _auto_detect_mask_token_id(base_config): + """Auto-detect an appropriate mask token ID for DFlash. + + Different model families use different strategies: + - Qwen3/3.5: built-in [MASK] token in vocabulary + - Llama3: reserved special tokens (128002 = reserved_special_token_0) + - Others: try tokenizer.mask_token_id, then fall back to pad/eos + """ + model_type = getattr(base_config, "model_type", "") + vocab_size = getattr(base_config, "vocab_size", 0) + + # Qwen3/3.5: known mask token positions + if "qwen3" in model_type.lower() or "qwen" in model_type.lower(): + # Qwen3 vocab has dedicated mask tokens + # Qwen3.5-4B: 248070, Qwen3-8B: similar range + # Heuristic: eos_token_id + some offset, or check known values + eos = getattr(base_config, "eos_token_id", None) + if isinstance(eos, list): + eos = eos[0] + if eos and vocab_size > 200000: + # Large Qwen vocab — mask token is typically near end of special tokens + # Known: Qwen3.5 eos=248044, mask=248070 (offset ~26) + # Try common offsets + for offset in [26, 25, 24]: + candidate = eos + offset + if candidate < vocab_size: + return candidate + # Fallback for smaller Qwen models + if vocab_size > 150000: + return vocab_size - 250 # heuristic for Qwen special token region + + # Llama3: use reserved_special_token_0 (128002) + if "llama" in model_type.lower(): + if vocab_size >= 128256: # Llama3 vocab size + return 128002 # <|reserved_special_token_0|> + + # Generic: try pad_token_id, then eos + pad_id = getattr(base_config, "pad_token_id", None) + eos_id = getattr(base_config, "eos_token_id", None) + if isinstance(eos_id, list): + eos_id = eos_id[0] + + # Prefer pad over eos (pad is less likely to interfere) + if pad_id is not None and pad_id != eos_id: + return pad_id + + # Last resort + return eos_id or 0 + + def _find_base_model_parts(self): + """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.""" + for name, paths in { + "base_model_path": ["model.language_model", "model", "backbone"], + "base_model_embeddings_path": [ + "model.embed_tokens", + "backbone.embeddings", + "model.language_model.embed_tokens", + ], + "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + }.items(): + for path in paths: + try: + submodule = self.get_submodule(path) + assert isinstance(submodule, torch.nn.Module) + setattr(self, name, path) + break + except Exception: + continue + else: + raise ValueError(f"Part {name} not found in model") + + def modify(self, config): + """Initialize DFlash draft module.""" + super().modify(config) + + base_config = self._base_llm_config + self.dflash_config = PretrainedConfig.from_dict(config.dflash_architecture_config) + + # Inherit settings from base model, but only those NOT already in the user config. + # hidden_size and vocab_size MUST match. Others (heads, intermediate_size) can differ. + # This allows the draft model to have a different architecture than the base model. + self.dflash_config.hidden_size = base_config.hidden_size + self.dflash_config.vocab_size = base_config.vocab_size + + # These use base model defaults if not specified in dflash_architecture_config + for attr, default_from_base in [ + ("max_position_embeddings", True), + ("intermediate_size", True), + ("num_attention_heads", True), + ("num_key_value_heads", True), + ("hidden_act", True), + ("rope_theta", True), + ("rope_scaling", True), + ("rope_type", False), + ("position_embedding_type", False), + ("rope_interleaved", False), + ("rms_norm_eps", True), + ("attention_bias", False), + ("tie_word_embeddings", False), + ]: + if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: + if default_from_base and hasattr(base_config, attr): + setattr(self.dflash_config, attr, getattr(base_config, attr)) + + # Ensure required attrs have defaults + if not hasattr(self.dflash_config, "mlp_bias") or self.dflash_config.mlp_bias is None: + self.dflash_config.mlp_bias = False + + self.dflash_config.head_dim = getattr( + self.dflash_config, + "head_dim", + self.dflash_config.hidden_size // self.dflash_config.num_attention_heads, + ) + self.dflash_config.block_size = self.dflash_block_size + if self.dflash_config._attn_implementation is None: + self.dflash_config._attn_implementation = "eager" + + # Target layer IDs + num_target_layers = base_config.num_hidden_layers + num_draft_layers = self.dflash_config.num_hidden_layers + self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers) + self.dflash_config.target_layer_ids = self.target_layer_ids + + # mask_token_id resolution order: + # 1. Explicit in dflash_architecture_config (user override) + # 2. Auto-detect from model vocabulary: + # - Qwen3/3.5: built-in [MASK] token + # - Llama3: reserved_special_token_0 (128002) + # - Others: tokenizer.mask_token_id + # 3. Fallback to pad_token_id or eos_token_id (suboptimal) + mask_id = config.dflash_architecture_config.get("mask_token_id", None) + if mask_id is None: + mask_id = self._auto_detect_mask_token_id(base_config) + self.mask_token_id = mask_id[0] if isinstance(mask_id, list) else mask_id + print(f"DFlash mask_token_id: {self.mask_token_id}") + + # Freeze base model + if self.dflash_freeze_base_model: + for param in self.parameters(): + param.requires_grad = False + + self._find_base_model_parts() + + # Resolve model-specific components (MLP, RMSNorm, RotaryEmbedding) + # from the base model's architecture for weight compatibility + global _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half + _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components( + getattr(base_config, "model_type", "llama") + ) + print(f"DFlash: using {_MLP_CLS.__name__} from {base_config.model_type}") + + self.dflash_module = DFlashModule(self.dflash_config) + self.dflash_module.to(self._base_model.dtype).to( + next(self._base_model.layers[-1].parameters()).device + ) + + self.is_quantized = False + + # Store bound reference to the original model class's forward. + # DynamicModule changes type(self) but the original class is in _original_cls. + # Find the original HF model class (e.g., Qwen3_5ForConditionalGeneration) + # by walking MRO and skipping DFlash/DynamicModule classes + skip_names = { + "HFDFlashModel", + "DFlashModel", + "DynamicModule", + "DFlashPreTrainedModel", + "DFlashDraftModel", + } + original_cls = None + for cls in type(self).__mro__: + if ( + hasattr(cls, "forward") + and cls.__name__ not in skip_names + and cls is not type(self) + and issubclass(cls, PreTrainedModel) + and cls is not PreTrainedModel + ): + original_cls = cls + break + if original_cls is None: + # Last resort: use the class two levels up (skip DFlash wrapper + DynamicModule) + original_cls = type(self).__mro__[2] + self._original_forward_cls = original_cls + print(f"DFlash: using {original_cls.__name__}.forward as base forward") + + def _base_forward(self, **kwargs): + """Call the original model's forward, bypassing DFlash wrapper.""" + return self._original_forward_cls.forward(self, **kwargs) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """Training forward matching SpecForge OnlineDFlashModel.forward.""" + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + device = input_ids.device + + # 1. Run base model → raw multi-layer hidden states + # Use super().forward() which goes through DynamicModule → original model + # (same pattern as EAGLE's HFEagleModel) + with torch.no_grad(): + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + # Extract and concatenate target layer hidden states + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + + # 2. Truncate to multiple of block_size + n_blocks = seq_len // block_size + effective_len = n_blocks * block_size + input_ids_trunc = input_ids[:, :effective_len] + target_hidden = target_hidden[:, :effective_len, :] + # Loss mask: use labels (response-only) if available, else attention_mask (padding) + if labels is not None: + # labels == -100 means "ignore" (system/user tokens when answer_only_loss=True) + loss_mask_input = (labels[:, :effective_len] != -100).float() + elif attention_mask is not None: + loss_mask_input = attention_mask[:, :effective_len].float() + else: + loss_mask_input = torch.ones(bsz, effective_len, device=device) + + # 3. Prepare noise: mask_token_id everywhere, real token at block starts + positions = torch.arange(effective_len, device=device) + is_block_start = (positions % block_size) == 0 + noise_input_ids = torch.full_like(input_ids_trunc, self.mask_token_id) + noise_input_ids[:, is_block_start] = input_ids_trunc[:, is_block_start] + noise_embedding = self._base_model_embeddings(noise_input_ids) + + # 4. Position IDs: [0..L-1, 0..L-1] + pos_seq = torch.arange(effective_len, device=device) + position_ids_2l = torch.cat([pos_seq, pos_seq]).unsqueeze(0).expand(bsz, -1) + + # 5. Attention mask: [1, 1, L, 2L] + dtype = target_hidden.dtype + dflash_attn_mask = create_dflash_attention_mask(effective_len, block_size, device, dtype) + + # 6. Draft forward + hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=position_ids_2l, + attention_mask=dflash_attn_mask, + ) + + # 7. Loss computation + logits = self._base_model_lm_head(hidden) + dflash_loss_mask = create_dflash_loss_mask(effective_len, block_size, device) + combined_mask = loss_mask_input * dflash_loss_mask.unsqueeze(0) + + logits_flat = logits.reshape(-1, logits.size(-1)) + labels_flat = input_ids_trunc.reshape(-1) + mask_flat = combined_mask.reshape(-1) + + active_indices = mask_flat > 0.5 + active_logits = logits_flat[active_indices] + active_labels = labels_flat[active_indices] + + if active_logits.numel() > 0: + if self.dflash_self_logit_distillation: + # Logit distillation: learn from target model's output distribution + # This works regardless of whether training data matches the target model + base_logits_trunc = base_outputs.logits[:, :effective_len, :] + base_logits_flat = base_logits_trunc.reshape(-1, base_logits_trunc.size(-1)) + active_base_logits = base_logits_flat[active_indices].detach() + target_soft = torch.softmax(active_base_logits, dim=-1) + draft_logsoft = torch.log_softmax(active_logits, dim=-1) + loss = -(target_soft * draft_logsoft).sum(dim=-1).mean() + else: + # Hard CE: predict ground truth tokens directly + # Only works well when training data is synthesized by the target model + loss = F.cross_entropy(active_logits, active_labels) + + with torch.no_grad(): + preds = active_logits.argmax(dim=-1) + accuracy = (preds == active_labels).float().mean().item() + else: + # No valid positions — compute a zero loss that still flows through + # dflash_module parameters to keep DDP gradient sync happy + loss = logits.sum() * 0.0 + accuracy = 0.0 + + return ModelOutput( + loss=loss, + logits=base_outputs.logits, + train_acc=[[accuracy]], + ) + + @torch.no_grad() + def pseudo_speculative_generate(self, input_ids, steps=1): + """Generate draft tokens using one DFlash block. + + DFlash generates block_size-1 draft tokens in a single forward pass. + The `steps` parameter is used as the number of tokens to return + (capped at block_size-1). + + Returns: + base_token: Next token from base model [B, 1]. + draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None. + """ + # Call the base model's inner model directly (avoids DynamicModule dispatch) + model_output = self._base_model( + input_ids=input_ids, + output_hidden_states=True, + ) + # Compute logits via lm_head + base_logits = self._base_model_lm_head(model_output.last_hidden_state) + # Build output with hidden_states + base_outputs = ModelOutput( + logits=base_logits, + hidden_states=model_output.hidden_states, + ) + base_logits = base_outputs.logits + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + + if steps < 1: + return base_token, None + + # Extract target hidden states (raw, before FC projection) + hid_offset = 1 + if not hasattr(self, "_psg_debug"): + self._psg_debug = True + sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + th_dbg = torch.cat(sel, dim=-1) + n_layers = len(base_outputs.hidden_states) + th_norm = th_dbg.norm().item() + print( + f"[psg] hidden layers: {n_layers}, target_hidden: {th_dbg.shape}, norm: {th_norm:.2f}" + ) + print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}") + seq_len = input_ids.shape[1] + blk = self.dflash_block_size + print(f"[psg] pos: ctx=[0..{seq_len - 1}], blk=[{seq_len}..{seq_len + blk - 1}]") + selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) + + block_size = self.dflash_block_size + bsz = input_ids.shape[0] + seq_len = input_ids.shape[1] + device = input_ids.device + dtype = target_hidden.dtype + + # Block: first token is base_token (anchor), rest are mask + block_ids = torch.full( + (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_ids[:, 0] = base_token.squeeze(-1) + noise_embedding = self._base_model_embeddings(block_ids) + + # Position IDs: training uses [0..L-1, 0..L-1] where noise positions + # mirror context positions. At inference, block predicts tokens at + # seq_len..seq_len+B-1, so noise positions continue from ctx_len. + ctx_len = target_hidden.shape[1] + ctx_positions = torch.arange(ctx_len, device=device) + block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) + pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) + + # Attention mask: block sees ALL context + reverse-causal within block + # Matching SpecForge training: j >= i (pos 0 sees all, pos B-1 sees only itself) + attn_mask = torch.zeros(1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype) + block_indices = torch.arange(block_size, device=device) + reverse_causal = block_indices.unsqueeze(0) >= block_indices.unsqueeze(1) + noise_mask = torch.zeros(block_size, block_size, device=device, dtype=dtype) + noise_mask.masked_fill_(~reverse_causal, torch.finfo(dtype).min) + attn_mask[:, :, :, ctx_len:] = noise_mask + + # Draft forward + draft_hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=attn_mask, + ) + + # Logits on positions 1..block_size-1 (skip anchor at position 0) + draft_logits = self._base_model_lm_head(draft_hidden[:, 1:, :]) + draft_tokens = draft_logits.argmax(dim=-1) # [B, block_size-1] + + # Return up to `steps` tokens + num_tokens = min(steps, block_size - 1) + return base_token, draft_tokens[:, :num_tokens] diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e147ebf2c..7aecbbfb1 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -186,9 +186,84 @@ def _process_chat_sample(self, examples: list): input_ids = tokenized_examples["input_ids"] labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) labels[..., :-1] = input_ids[..., 1:] + if self.answer_only_loss: + # Try tokenizer's assistant_masks first + if "assistant_masks" in tokenized_examples: + assistant_mask = tokenized_examples["assistant_masks"] + if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any(): + labels[assistant_mask == 0] = IGNORE_TOKEN_ID + else: + # Fallback: derive from formatted text using regex + labels = self._apply_answer_only_labels(examples, labels, input_ids) + else: + labels = self._apply_answer_only_labels(examples, labels, input_ids) tokenized_examples["labels"] = labels return tokenized_examples + def _apply_answer_only_labels(self, examples, labels, input_ids): + """Derive response-only labels by finding assistant spans in formatted text. + + Uses regex to find assistant response spans in the chat-template-formatted text, + then maps character positions to token positions via offset mapping. + Similar to SpecForge's _apply_loss_mask_from_chat_template. + """ + import re + + for batch_idx, conversation in enumerate(examples): + # Format with chat template + formatted = self.tokenizer.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=False + ) + + # Tokenize with offset mapping + try: + encoding = self.tokenizer( + formatted, + return_offsets_mapping=True, + max_length=self.train_len, + truncation=True, + add_special_tokens=False, + ) + offsets = encoding["offset_mapping"] + except Exception: + # Tokenizer doesn't support offset mapping — keep all labels + continue + + # Find assistant response spans + # Common patterns across chat templates + # Try to detect the assistant marker from the formatted text + assistant_markers = [ + r"<\|im_start\|>assistant\n(.*?)(?:<\|im_end\|>|$)", # Qwen/ChatML + r"<\|start_header_id\|>assistant<\|end_header_id\|>\n\n(.*?)(?:<\|eot_id\|>|$)", # Llama3 + r"\[/INST\](.*?)(?:|$)", # Llama2 + r"assistant\n(.*?)(?:\n\n|$)", # Generic + ] + + found = False + for pattern in assistant_markers: + matches = list(re.finditer(pattern, formatted, re.DOTALL)) + if matches: + # Mask all tokens, then unmask assistant spans + labels[batch_idx, :] = IGNORE_TOKEN_ID + for match in matches: + start_char = match.start(1) + end_char = match.end(1) + for tok_idx, (tok_start, tok_end) in enumerate(offsets): + if tok_idx >= labels.shape[1]: + break + if tok_start >= start_char and tok_end <= end_char: + # Restore the shifted label for this position + if tok_idx < input_ids.shape[1] - 1: + labels[batch_idx, tok_idx] = input_ids[batch_idx, tok_idx + 1] + found = True + break + + if not found: + # No assistant pattern found — keep all labels (don't mask) + pass + + return labels + def _process_text_sample(self, examples: list): tokenized_examples = self.tokenizer( examples, diff --git a/tests/gpu/torch/speculative/plugins/test_hf_dflash.py b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 000000000..230b67c45 --- /dev/null +++ b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for DFlash speculative decoding plugin. + +These tests require a CUDA GPU. CPU-only tests are in tests/unit/. +""" + +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import get_tiny_llama + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, + } + return config + + +@pytest.fixture +def dflash_model(): + """Create a tiny DFlash model on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + return model + + +class TestDFlashModuleGPU: + """Test DFlash draft module forward pass on GPU.""" + + def test_dflash_module_forward_shape(self, dflash_model): + """Test that draft module produces correct output shape.""" + model = dflash_model + bsz = 2 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = ( + torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]) + .unsqueeze(0) + .expand(bsz, -1) + .cuda() + ) + + output = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + assert output.shape == (bsz, SEQ_LEN, hidden_size) + + def test_dflash_module_deterministic(self, dflash_model): + """Test that draft module produces identical outputs for same input.""" + model = dflash_model + model.eval() + bsz = 1 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).cuda() + + with torch.no_grad(): + out1 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + out2 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + assert torch.allclose(out1, out2) + + +class TestDFlashTrainingForwardGPU: + """Test DFlash training forward pass end-to-end on GPU.""" + + @pytest.fixture + def model(self): + """Create a tiny DFlash model in training mode on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + model.train() + return model + + def test_training_forward_returns_loss(self, model): + """Test that training forward returns a differentiable loss.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_returns_accuracy(self, model): + """Test that training forward returns train_acc.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "train_acc") + + def test_training_forward_with_labels(self, model): + """Test that labels are used for response-only loss masking.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + # Labels with -100 for first half (masked), real labels for second half + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + labels[:, SEQ_LEN // 2 :] = input_ids[:, SEQ_LEN // 2 :] + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_all_masked_labels(self, model): + """Test that all-masked labels produce zero loss without crashing.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert output.loss.item() == 0.0 + + def test_training_backward(self, model): + """Test that gradients flow to dflash_module.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + output.loss.backward() + + has_grad = False + for name, param in model.dflash_module.named_parameters(): + if param.grad is not None and param.grad.abs().sum() > 0: + has_grad = True + break + assert has_grad, "DFlash module should receive gradients" + + def test_eval_forward_uses_base_model(self, model): + """In eval mode, forward should use base model (not DFlash training).""" + model.eval() + bsz = 1 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + + with torch.no_grad(): + output = model(input_ids=input_ids) + assert output.logits.shape == (bsz, SEQ_LEN, model.config.vocab_size) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 000000000..50d3c9768 --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for DFlash speculative decoding plugin. + +GPU-dependent tests (training forward, module forward) are in tests/gpu/. +""" + +import os +from copy import deepcopy + +import torch +from _test_utils.torch.transformers_models import ( + get_tiny_llama, + tf_modelopt_state_and_output_tester, +) +from transformers import AutoModelForCausalLM + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG +from modelopt.torch.speculative.plugins.hf_dflash import ( + DFlashModule, + HFDFlashModel, + create_dflash_attention_mask, + create_dflash_loss_mask, +) + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, # use token 0 as mask for tiny model + } + return config + + +class TestDFlashConvert: + """Test DFlash model conversion.""" + + def test_convert_creates_dflash_model(self): + """Test that convert produces an HFDFlashModel.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert isinstance(model, HFDFlashModel) + + def test_convert_creates_dflash_module(self): + """Test that convert attaches a DFlashModule.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "dflash_module") + assert isinstance(model.dflash_module, DFlashModule) + + def test_convert_freezes_base_model(self): + """Test that base model parameters are frozen after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + for name, param in model.named_parameters(): + if "dflash_module" not in name: + assert not param.requires_grad, f"Base param {name} should be frozen" + + def test_convert_dflash_module_trainable(self): + """Test that DFlash module parameters are trainable after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + dflash_params = [(n, p) for n, p in model.named_parameters() if "dflash_module" in n] + assert len(dflash_params) > 0 + for name, param in dflash_params: + assert param.requires_grad, f"DFlash param {name} should be trainable" + + def test_convert_sets_target_layer_ids(self): + """Test that target layer IDs are set correctly.""" + model = get_tiny_llama(num_hidden_layers=8) + config = _get_dflash_config(num_layers=3) + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "target_layer_ids") + assert len(model.target_layer_ids) == 3 + for lid in model.target_layer_ids: + assert 0 <= lid < 8 + + def test_convert_sets_mask_token_id(self): + """Test that mask_token_id is set from config.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "mask_token_id") + assert model.mask_token_id == 0 + + +class TestDFlashSaveRestore: + """Test DFlash model save and restore.""" + + def test_save_and_restore(self, tmp_path): + """Test round-trip save/load preserves modelopt state and outputs.""" + mto.enable_huggingface_checkpointing() + model_ref = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model_ref, [("dflash", config)]) + + model_ref.save_pretrained(tmp_path / "modelopt_model") + assert os.path.exists(tmp_path / "modelopt_model/modelopt_state.pth") + + model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model") + assert isinstance(model_test, HFDFlashModel) + tf_modelopt_state_and_output_tester(model_ref, model_test) + + +class TestDFlashAttentionMask: + """Test DFlash attention mask construction.""" + + def test_mask_shape(self): + """Test mask has shape [1, 1, L, 2L].""" + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + assert mask.shape == (1, 1, SEQ_LEN, 2 * SEQ_LEN) + + def test_mask_context_strictly_previous_blocks(self): + """Context (left half): block B can only see blocks 0..B-1.""" + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] # [8, 16] + ctx_mask = mask_2d[:, :8] # context part + + # Block 0 (rows 0-3) should NOT see any context + assert (ctx_mask[:4, :] < 0).all() + + # Block 1 (rows 4-7) should see block 0 context only + assert (ctx_mask[4:8, :4] == 0).all() # can see block 0 + assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block + + def test_mask_noise_causal_within_block(self): + """Noise (right half): reverse-causal within same block, matching SpecForge. + + SpecForge uses j >= i: position 0 (anchor) sees all positions in block, + position B-1 sees only itself. Cross-block noise is fully masked. + """ + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] + noise_mask = mask_2d[:, 8:] # noise part + + # Block 0, position 0: can see all positions in block (0-3) + assert (noise_mask[0, :4] == 0).all() + + # Block 0, position 3: can only see position 3 + assert (noise_mask[3, :3] < 0).all() + assert noise_mask[3, 3] == 0 + + # Block 1 cannot see block 0 noise + assert (noise_mask[4:8, :4] < 0).all() + + def test_mask_values_are_zero_or_neg_inf(self): + """Test mask contains only 0 (attend) and -inf (mask).""" + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + unique_vals = mask.unique() + assert len(unique_vals) == 2 + assert 0.0 in unique_vals + assert unique_vals.min() == torch.finfo(torch.float32).min + + +class TestDFlashLossMask: + """Test DFlash loss mask construction.""" + + def test_loss_mask_shape(self): + """Test loss mask has shape [L].""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + assert mask.shape == (SEQ_LEN,) + + def test_loss_mask_excludes_block_zero(self): + """Test all positions in block 0 are masked out.""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + assert (mask[:BLOCK_SIZE] == 0).all() + + def test_loss_mask_excludes_block_starts(self): + """Test block start positions are masked.""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + for i in range(0, SEQ_LEN, BLOCK_SIZE): + assert mask[i] == 0, f"Block start position {i} should be masked" + + def test_loss_mask_includes_non_start_positions(self): + """Test non-start positions in non-zero blocks are included.""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + for b in range(1, SEQ_LEN // BLOCK_SIZE): + for offset in range(1, BLOCK_SIZE): + pos = b * BLOCK_SIZE + offset + assert mask[pos] == 1, f"Position {pos} should be in loss" + + def test_loss_mask_count(self): + """Test total active positions matches expected count.""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + num_blocks = SEQ_LEN // BLOCK_SIZE + expected = (num_blocks - 1) * (BLOCK_SIZE - 1) + assert mask.sum().item() == expected + + +class TestBuildTargetLayerIds: + """Test target layer selection.""" + + def test_single_draft_layer(self): + """Test single draft layer selects middle target layer.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 1) + assert len(ids) == 1 + assert ids[0] == 16 # middle layer + + def test_multiple_draft_layers(self): + """Test multiple draft layers are monotonically increasing and in bounds.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(36, 5) + assert len(ids) == 5 + assert ids == sorted(ids) + assert all(1 <= lid <= 33 for lid in ids) + + def test_layer_ids_spread(self): + """Test layer IDs have no duplicates.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 5) + assert len(ids) == 5 + assert len(set(ids)) == 5 diff --git a/tools/launcher/common/dflash/ar_validate.sh b/tools/launcher/common/dflash/ar_validate.sh new file mode 100644 index 000000000..01ad61ffd --- /dev/null +++ b/tools/launcher/common/dflash/ar_validate.sh @@ -0,0 +1,118 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash AR (Acceptance Rate) validation script. +# Loads a trained DFlash checkpoint and evaluates speculative decoding AR on MT-Bench. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# DFLASH_CKPT — path to the trained DFlash checkpoint +# DFLASH_BLOCK_SIZE — block size (default: 16) +# DFLASH_NUM_LAYERS — number of draft layers (default: 5) +# DFLASH_MASK_TOKEN_ID — mask token ID (default: auto-detect) +# NUM_SAMPLES — number of MT-Bench samples to evaluate (default: 20) + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh +trap 'error_handler $0 $LINENO' ERR + +DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} +DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} +NUM_SAMPLES=${NUM_SAMPLES:-20} + +# Build mask_token_id arg +if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then + MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +else + MASK_ARG="" +fi + +echo "=== DFlash AR Validation ===" +echo "Target model: ${HF_MODEL_CKPT}" +echo "DFlash checkpoint: ${DFLASH_CKPT}" +echo "Block size: ${DFLASH_BLOCK_SIZE}" +echo "Draft layers: ${DFLASH_NUM_LAYERS}" +echo "Samples: ${NUM_SAMPLES}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${DFLASH_CKPT}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights from {len(ckpt_files)} file(s)') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... | {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh new file mode 100644 index 000000000..0cdb6a906 --- /dev/null +++ b/tools/launcher/common/dflash/online_training.sh @@ -0,0 +1,163 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash online training + AR validation script for the ModelOpt Launcher. +# Trains a DFlash draft model alongside the frozen target model, +# then evaluates acceptance rate on MT-Bench. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# +# Optional env vars: +# NUM_AR_SAMPLES — number of MT-Bench samples for AR validation (default: 20, 0 to skip) +# +# All other args are passed through to launch_train.sh. + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh + +pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt +pip install huggingface-hub>=1.2.1 +export PATH=$PATH:/workspace/.local/bin + +################################################################################################### + +trap 'error_handler $0 $LINENO' ERR + +# Parse DFlash-specific args from the command line for AR validation +DFLASH_BLOCK_SIZE=16 +DFLASH_NUM_LAYERS=5 +DFLASH_MASK_TOKEN_ID="" +OUTPUT_DIR="" +for arg in "$@"; do + case "$arg" in + --dflash_block_size) next_is_block_size=1 ;; + --dflash_num_layers) next_is_num_layers=1 ;; + --dflash_mask_token_id) next_is_mask_id=1 ;; + --output_dir) next_is_output=1 ;; + *) + if [ "$next_is_block_size" = "1" ]; then DFLASH_BLOCK_SIZE="$arg"; next_is_block_size=0; fi + if [ "$next_is_num_layers" = "1" ]; then DFLASH_NUM_LAYERS="$arg"; next_is_num_layers=0; fi + if [ "$next_is_mask_id" = "1" ]; then DFLASH_MASK_TOKEN_ID="$arg"; next_is_mask_id=0; fi + if [ "$next_is_output" = "1" ]; then OUTPUT_DIR="$arg"; next_is_output=0; fi + ;; + esac +done + +# Step 1: Training +bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ + --model ${HF_MODEL_CKPT} \ + --mode dflash \ + ${@} + +# Step 2: AR Validation +NUM_AR_SAMPLES=${NUM_AR_SAMPLES:-20} +if [ "${NUM_AR_SAMPLES}" = "0" ]; then + echo "Skipping AR validation (NUM_AR_SAMPLES=0)" + exit 0 +fi + +if [ -z "$OUTPUT_DIR" ]; then + echo "WARNING: --output_dir not found in args, skipping AR validation" + exit 0 +fi + +# Build mask_token_id config +if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then + MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +else + MASK_ARG="" +fi + +echo "" +echo "=== DFlash AR Validation ===" +echo "Target model: ${HF_MODEL_CKPT}" +echo "DFlash checkpoint: ${OUTPUT_DIR}" +echo "Block size: ${DFLASH_BLOCK_SIZE}" +echo "Draft layers: ${DFLASH_NUM_LAYERS}" +echo "Samples: ${NUM_AR_SAMPLES}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${OUTPUT_DIR}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights from {len(ckpt_files)} file(s)') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_AR_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... | {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" + +################################################################################################### diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml new file mode 100644 index 000000000..c72b5aec4 --- /dev/null +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -0,0 +1,63 @@ +# DFlash online speculative decoding training for Qwen3-8B. +# +# Trains a DFlash draft model (block diffusion) using the frozen target model +# to extract multi-layer hidden states on the fly, then evaluates AR on MT-Bench. +# +# 2-step pipeline: +# task_0: Online DFlash training + AR validation +# task_1: Benchmark speculative decoding speedup via VLLM +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes + +job_name: Qwen3-8B_DFlash_online +pipeline: + # Step 1: Online DFlash training + AR validation + task_0: + script: common/dflash/online_training.sh + args: + - --data /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K.jsonl + - --output_dir /scratchspace/dflash + - --num_epochs 3 + - --lr 1e-4 + - --training_seq_len 512 + - --save_steps 500000 + - --log_steps 100 + - --disable_tqdm True + - --ar_validate_steps 0 + - --dflash_block_size 16 + - --dflash_num_layers 5 + - --dflash_mask_token_id 151643 + environment: + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - NUM_AR_SAMPLES: "20" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + + # Step 2: Benchmark speculative decoding (VLLM backend) + task_1: + script: common/specdec_bench/quick_check.sh + args: + - --draft_model_dir /scratchspace/dflash + - --draft_length 3 + - --output_length 4096 + - --engine VLLM + - --tp_size 4 + - --ep_size 1 + - --speculative_algorithm EAGLE3 + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --concurrency 1 + environment: + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/uv.lock b/uv.lock index d890e361c..3cd6db208 100644 --- a/uv.lock +++ b/uv.lock @@ -20,9 +20,6 @@ resolution-markers = [ "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] -[manifest] -overrides = [{ name = "torch", marker = "sys_platform == 'never'" }] - [[package]] name = "accelerate" version = "1.13.0" @@ -35,7 +32,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } wheels = [ @@ -480,6 +477,21 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/54/27/01d9078a77b9e31b79b9716e66ca4db74f4744c5232bcb3e8769395c4280/cppimport-22.8.2.tar.gz", hash = "sha256:bbb4957102db41bc99ad72c233bce92f9d1fd91be352fc07878c4361033a401f", size = 26635, upload-time = "2022-08-02T16:50:36.872Z" } +[[package]] +name = "cuda-bindings" +version = "12.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/45/e7/b47792cc2d01c7e1d37c32402182524774dadd2d26339bd224e0e913832e/cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c912a3d9e6b6651853eed8eed96d6800d69c08e94052c292fec3f282c5a817c9", size = 12210593, upload-time = "2025-10-21T14:51:36.574Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, + { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, + { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, +] + [[package]] name = "cuda-pathfinder" version = "1.4.3" @@ -554,7 +566,7 @@ dependencies = [ { name = "psutil", marker = "sys_platform != 'win32'" }, { name = "py-cpuinfo", marker = "sys_platform != 'win32'" }, { name = "pydantic", marker = "sys_platform != 'win32'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch", marker = "sys_platform != 'win32'" }, { name = "tqdm", marker = "sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/11/46b9eb3806ca7a5e9bdddb7e873855a2d59a9f87f0675ae8231678d98434/deepspeed-0.18.8.tar.gz", hash = "sha256:e4e051a144b0c74270c46e4970139f9a86a61ff26959c5e463000c4a93b99304", size = 1647226, upload-time = "2026-03-13T18:49:48.568Z" } @@ -1311,7 +1323,9 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version < '3.11' and sys_platform == 'darwin')", + "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'win32'", "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } @@ -1324,12 +1338,18 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version >= '3.13' and sys_platform == 'darwin')", "(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.12.*' and sys_platform == 'darwin')", "(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.11.*' and sys_platform == 'darwin')", "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1526,6 +1546,108 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/a7/b35835e278c18b85206834b3aa3abe68e77a98769c59233d1f6300284781/numpy-2.4.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:4b42639cdde6d24e732ff823a3fa5b701d8acad89c4142bc1d0bd6dc85200ba5", size = 12504685, upload-time = "2026-03-09T07:58:50.525Z" }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + [[package]] name = "nvidia-ml-py" version = "13.595.45" @@ -1554,7 +1676,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "setuptools" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, ] @@ -1756,11 +1878,43 @@ requires-dist = [ { name = "tox", marker = "extra == 'dev-test'", specifier = ">4.18" }, { name = "tox-current-env", marker = "extra == 'dev-test'", specifier = ">=0.0.12" }, { name = "tqdm" }, - { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53,<5.0" }, + { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.56,<5.0" }, { name = "wonderwords", marker = "extra == 'hf'" }, ] provides-extras = ["onnx", "hf", "dev-lint", "dev-docs", "dev-test", "all", "dev"] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu12" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + [[package]] name = "omegaconf" version = "2.3.0" @@ -2157,7 +2311,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -3403,7 +3557,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d7/2c/593109822fe735e637382aca6640c1102c19797f7791f1fd1dab2d6c3cb1/timm-1.0.25.tar.gz", hash = "sha256:47f59fc2754725735cc81bb83bcbfce5bec4ebd5d4bb9e69da57daa92fcfa768", size = 2414743, upload-time = "2026-02-23T16:49:00.137Z" } @@ -3491,15 +3645,63 @@ name = "torch" version = "2.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, + { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" }, + { url = "https://files.pythonhosted.org/packages/76/bb/d820f90e69cda6c8169b32a0c6a3ab7b17bf7990b8f2c680077c24a3c14c/torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d", size = 79411450, upload-time = "2026-01-21T16:25:30.692Z" }, + { url = "https://files.pythonhosted.org/packages/78/89/f5554b13ebd71e05c0b002f95148033e730d3f7067f67423026cc9c69410/torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4", size = 145992610, upload-time = "2026-01-21T16:25:26.327Z" }, + { url = "https://files.pythonhosted.org/packages/ae/30/a3a2120621bf9c17779b169fc17e3dc29b230c29d0f8222f499f5e159aa8/torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763", size = 915607863, upload-time = "2026-01-21T16:25:06.696Z" }, + { url = "https://files.pythonhosted.org/packages/6f/3d/c87b33c5f260a2a8ad68da7147e105f05868c281c63d65ed85aa4da98c66/torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd", size = 113723116, upload-time = "2026-01-21T16:25:21.916Z" }, + { url = "https://files.pythonhosted.org/packages/61/d8/15b9d9d3a6b0c01b883787bd056acbe5cc321090d4b216d3ea89a8fcfdf3/torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b", size = 79423461, upload-time = "2026-01-21T16:24:50.266Z" }, + { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, + { url = "https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, + { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" }, + { url = "https://files.pythonhosted.org/packages/c9/5c/dee910b87c4d5c0fcb41b50839ae04df87c1cfc663cf1b5fca7ea565eeaa/torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294", size = 79498198, upload-time = "2026-01-21T16:24:34.704Z" }, + { url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, + { url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, + { url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, + { url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, + { url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, + { url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, + { url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, + { url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, +] [[package]] name = "torch-geometric" @@ -3529,7 +3731,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6f/36/574c0c46e818533b78b3c09505211162918188325ab4165ef11a3f295755/torchprofile-0.0.4.tar.gz", hash = "sha256:96b6da17d752a06b02977e078aea95614893b31d4117dd5dcd081f30ce65611b", size = 4557, upload-time = "2021-06-22T04:58:03.592Z" } @@ -3545,7 +3747,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pillow" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/50/ae/cbf727421eb73f1cf907fbe5788326a08f111b3f6b6ddca15426b53fec9a/torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a95c47abb817d4e90ea1a8e57bd0d728e3e6b533b3495ae77d84d883c4d11f56", size = 1874919, upload-time = "2026-01-21T16:27:47.617Z" }, @@ -3638,6 +3840,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" }, ] +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0"