From 6c038f97c2b523ae4fdc5703457589a1b960f87a Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 27 Oct 2025 17:42:17 +0530 Subject: [PATCH 01/62] Add modelopt/torch/_compress CODEOWNERS Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c74be84985..3c3d40a000 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -17,6 +17,7 @@ modelopt/deploy @NVIDIA/modelopt-deploy-codeowners modelopt/onnx @NVIDIA/modelopt-onnx-codeowners modelopt/onnx/autocast @NVIDIA/modelopt-onnx-autocast-codeowners modelopt/torch @NVIDIA/modelopt-torch-codeowners +modelopt/torch/_compress @NVIDIA/modelopt-torch-compress-codeowners modelopt/torch/_deploy @NVIDIA/modelopt-torch-deploy-codeowners modelopt/torch/distill @NVIDIA/modelopt-torch-distill-codeowners modelopt/torch/export @NVIDIA/modelopt-torch-export-codeowners From 54c5f0fdb7c0fd83a96fe4b8f59f7e8314c1e8f7 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 27 Oct 2025 12:37:45 -0700 Subject: [PATCH 02/62] Remove llm_ptq example tests from CICD Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/example_tests.yml | 103 ---------------------------- 1 file changed, 103 deletions(-) delete mode 100644 .github/workflows/example_tests.yml diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml deleted file mode 100644 index 21eba5f864..0000000000 --- a/.github/workflows/example_tests.yml +++ /dev/null @@ -1,103 +0,0 @@ -# NOTE: Make sure this file is consistent with .gitlab/tests.yml -name: E2E Example tests - -on: - push: - branches: ["pull-request/[0-9]+"] - # NOTE: paths cannot be used since push happens to copied PR and only latest commit to PR is used - schedule: - - cron: "0 0 * * *" # Nightly - workflow_dispatch: # On-demand - -# Cancel previous runs if new commit is pushed to the same PR -concurrency: - group: ${{ github.workflow }}-${{ startsWith(github.ref, 'refs/heads/pull-request/') && github.ref || github.sha }} - cancel-in-progress: true - -jobs: - check-file-changes: - if: startsWith(github.ref, 'refs/heads/pull-request/') - runs-on: ubuntu-latest - outputs: - any_changed: ${{ steps.changed-tests.outputs.any_changed }} - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - id: get-pr-info - uses: nv-gha-runners/get-pr-info@main - # Get commit from main branch that is present in the PR to use as base for changed files - - id: calculate-merge-base - env: - PR_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }} - BASE_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).base.sha }} - run: | - (echo -n "merge-base="; git merge-base "$BASE_SHA" "$PR_SHA") | tee --append "${GITHUB_OUTPUT}" - - name: Check for changes in test-relevant directories - id: changed-tests - uses: step-security/changed-files@v46.0.5 - with: - base_sha: ${{ steps.calculate-merge-base.outputs.merge-base }} - sha: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }} - files: | - .github/workflows/example_tests.yml - examples/llm_ptq/** - modelopt/torch/** - tests/examples/llm_ptq/** - setup.py - fail_on_initial_diff_error: true - wait-checks: - needs: [check-file-changes] - if: needs.check-file-changes.outputs.any_changed == 'true' - uses: ./.github/workflows/_wait_for_checks.yml - permissions: - checks: read - secrets: inherit - with: - match_pattern: '^DCO$|^linux$' # Wait for DCO and Unit tests / linux to pass - delay: 300s - example-tests-pr: - needs: [check-file-changes, wait-checks] - if: needs.check-file-changes.outputs.any_changed == 'true' - # Runner list at https://github.com/nv-gha-runners/enterprise-runner-configuration/blob/main/docs/runner-groups.md - runs-on: linux-amd64-gpu-h100-latest-1 - timeout-minutes: 90 - strategy: - matrix: - EXAMPLE: [llm_ptq] - container: &example_container - image: nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2.post2 - env: - PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages - HF_TOKEN: ${{ secrets.HF_TOKEN }} - steps: &example_steps - - uses: actions/checkout@v4 - - uses: nv-gha-runners/setup-proxy-cache@main - - name: Setup environment variables - run: | - echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu:/usr/local/tensorrt/targets/x86_64-linux-gnu/lib" >> $GITHUB_ENV - echo "PATH=${PATH}:/usr/local/tensorrt/targets/x86_64-linux-gnu/bin" >> $GITHUB_ENV - - name: Run example tests - run: | - pip install ".[hf,dev-test]" - find examples/${{ matrix.EXAMPLE }} -name "requirements.txt" | while read req_file; do pip install -r "$req_file" || exit 1; done - pytest -s tests/examples/${{ matrix.EXAMPLE }} - example-tests-non-pr: - if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} - # Runner list at https://github.com/nv-gha-runners/enterprise-runner-configuration/blob/main/docs/runner-groups.md - runs-on: linux-amd64-gpu-h100-latest-1 - timeout-minutes: 90 - strategy: - matrix: - EXAMPLE: [llm_ptq] - container: *example_container - steps: *example_steps - example-pr-required-check: - # Run even if example-tests-pr is skipped - if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} - needs: [check-file-changes, example-tests-pr] - runs-on: ubuntu-latest - steps: - - name: Required GPU tests did not succeed - if: ${{ needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && needs.example-tests-pr.result != 'success') }} - run: exit 1 From 9eeee251cf6fbe1fe62431efc44b9b183673c738 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 28 Oct 2025 20:46:15 +0100 Subject: [PATCH 03/62] E2E test for the experimental compress algorithm based on https://arxiv.org/abs/2411.19146 (#464) Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/_compress/README.md | 3 + modelopt/torch/_compress/compress.py | 82 +++ modelopt/torch/_compress/runtime.py | 556 ++++++++++++++++++ tests/_test_utils/torch_dist/dist_utils.py | 5 + .../resources/configs/Llama-3_1-8B.yaml | 108 ++++ .../configs/pruning/attn_pruning.yaml | 16 + .../configs/pruning/ffn_pruning.yaml | 12 + .../configs/pruning/hidden_dim_pruning.yaml | 15 + .../configs/pruning/pruning_defaults.yaml | 32 + .../configs/validate_model_defaults.yaml | 15 + .../configs/validate_solutions_defaults.yaml | 10 + .../tokenizer/special_tokens_map.json | 16 + .../resources/tokenizer/tokenizer.json | 212 +++++++ .../resources/tokenizer/tokenizer_config.json | 13 + .../resources/tokenizer/truncate_tokenizer.py | 62 ++ .../torch/_compress/test_compress.py | 240 ++++++++ 16 files changed, 1397 insertions(+) create mode 100644 modelopt/torch/_compress/README.md create mode 100644 modelopt/torch/_compress/compress.py create mode 100644 modelopt/torch/_compress/runtime.py create mode 100644 tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml create mode 100644 tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml create mode 100644 tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml create mode 100644 tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml create mode 100644 tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml create mode 100644 tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml create mode 100644 tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml create mode 100644 tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json create mode 100644 tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json create mode 100644 tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json create mode 100644 tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py create mode 100644 tests/experimental/torch/_compress/test_compress.py diff --git a/modelopt/torch/_compress/README.md b/modelopt/torch/_compress/README.md new file mode 100644 index 0000000000..4c6da80e54 --- /dev/null +++ b/modelopt/torch/_compress/README.md @@ -0,0 +1,3 @@ +Experimental model compression algorithm based on a Local Neural Architecture Search. +Based on the Puzzle paper: +PoC for Llama 3.1 model. diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py new file mode 100644 index 0000000000..265fd5eeb2 --- /dev/null +++ b/modelopt/torch/_compress/compress.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + +This module provides the main compression function for a model +using MIP-based NAS search algorithm. + +""" + +import build_library_and_stats +import mip_and_realize_models +import pruning_ckpts +import score_pruning_activations +import scoring +from omegaconf import DictConfig +from puzzle_tools.runtime import IRuntime + +# TODO Move initialize_hydra_config_for_dir from tests to main +from tests.utils.test_utils import initialize_hydra_config_for_dir + + +def compress( + hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str, runtime: IRuntime +) -> DictConfig: + """Compress a puzzletron model using the MIP-based NAS search algorithm. + + Args: + hydra_config_dir (str): path to a hydra_config_dir that defines the search space + hydra_config (str): the corresponding hydra config file + puzzle_dir (str): directory with a puzzletron model to compress + dataset_path (str): dataset used for scoring and distillation + runtime: distributed runtime to use to run the compression steps, e.g., + NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)) + + Returns: + Hydra config object after compressing the model. + The same hydra configuration object is used across all compression steps. + @TODO: Investigate if this config object is immutable across steps and clarify + """ + # Step 0: Load puzzletron hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config, + overrides=[ + f"puzzle_dir={puzzle_dir}", + f"dataset_path={dataset_path}", + ], + ) + + # Step 1: score_pruning_activations (distributed processing) + score_pruning_activations.launch_score_activations(hydra_cfg, runtime) + + # Step 2: pruning_ckpts (single process) + if runtime.global_rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + runtime.wait_for_everyone() + + # Step 4: build_library_and_stats (single process) + if runtime.global_rank == 0: + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + runtime.wait_for_everyone() + + # Step 5: calc_one_block_scores (distributed processing) + scoring.launch_scoring(hydra_cfg, runtime) + + # Step 6: mip_and_realize_models (distributed processing) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + + return hydra_cfg diff --git a/modelopt/torch/_compress/runtime.py b/modelopt/torch/_compress/runtime.py new file mode 100644 index 0000000000..46f561a5d9 --- /dev/null +++ b/modelopt/torch/_compress/runtime.py @@ -0,0 +1,556 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for torch distributed runtime management""" + +import os +import random +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import AbstractContextManager, suppress +from datetime import timedelta +from pathlib import Path +from typing import Literal, TypeVar, cast + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from torch.utils.data import DataLoader +from tqdm import tqdm +from typing_extensions import override + +PrepareModelsT = TypeVar("PrepareModelsT", bound=Sequence[nn.Module]) +PrepareDataLoaderT = TypeVar("PrepareDataLoaderT", bound=DataLoader) +CompileT = TypeVar("CompileT", bound=nn.Module) +Filter = ( + Literal["main_process", "last", "local_main_process", "local_last", "all"] + | list[int] + | set[int] + | Callable[[int], bool] +) + + +class IRuntime(ABC): + @abstractmethod + def setup(self) -> None: ... + + @abstractmethod + def cleanup(self) -> None: ... + + @abstractmethod + def autocast(self) -> AbstractContextManager: ... + + @abstractmethod + def wait_for_everyone(self) -> None: ... + + @abstractmethod + def set_seed(self, seed: int, device_specific: bool = False) -> int: ... + + @abstractmethod + def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: ... + + @abstractmethod + def prepare_train_dataloader( + self, train_dataloader: PrepareDataLoaderT + ) -> PrepareDataLoaderT: ... + + @abstractmethod + def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: ... + + @abstractmethod + def compile(self, model: CompileT) -> CompileT: ... + + @abstractmethod + def backward(self, loss: torch.Tensor) -> None: ... + + @abstractmethod + def clip_grad_norm_( + self, + parameters: Iterable[torch.Tensor] | torch.Tensor, + max_norm: float, + norm_type: float = 2, + ) -> torch.Tensor: ... + + @abstractmethod + def clip_grad_value_( + self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float + ) -> None: ... + + @abstractmethod + def save_state(self, path: str | Path) -> None: ... + + @abstractmethod + def load_state(self, path: str | Path) -> None: ... + + @abstractmethod + def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: ... + + @property + @abstractmethod + def sync_gradients(self) -> bool: ... + + @property + @abstractmethod + def device(self) -> torch.device: ... + + @property + @abstractmethod + def is_main_process(self) -> bool: ... + + @property + @abstractmethod + def is_local_main_process(self) -> bool: ... + + @property + @abstractmethod + def is_last_process(self) -> bool: ... + + @property + @abstractmethod + def is_local_last_process(self) -> bool: ... + + @property + @abstractmethod + def local_rank(self) -> int: ... + + @property + @abstractmethod + def global_rank(self) -> int: ... + + @property + @abstractmethod + def local_world_size(self) -> int: ... + + @property + @abstractmethod + def world_size(self) -> int: ... + + @property + @abstractmethod + def dtype(self) -> torch.dtype: ... + + def __enter__(self): + self.setup() + return self + + def __exit__(self, exc_type, exc_value, traceback): + # avoid barrier if exceution errored + if exc_type is None: + self.cleanup() + + # if exc_type is not None: + # raise exc_value + # Handle exceptions if necessary + # pass + + # def __del__(self): + # torch.distributed.barrier() + # torch.distributed.destroy_process_group() + + def check_filter(self, filter_: Filter): + return ( + filter_ == "all" + or (filter_ == "main_process" and self.is_main_process) + or (filter_ == "local_main_process" and self.is_local_main_process) + or (filter_ == "last" and self.is_last_process) + or (filter_ == "local_last" and self.is_local_last_process) + or (isinstance(filter_, (list, set)) and self.global_rank in filter_) + or (callable(filter_) and filter_(self.global_rank)) + ) + + def print( + self, *args, filter_: Filter = "main_process", rank_prefix=False, flush=True, **kwargs + ) -> None: + if not self.check_filter(filter_): + return + + if rank_prefix: + print(f"[global_rank={self.global_rank}]", *args, flush=flush, **kwargs) + else: + print(*args, flush=flush, **kwargs) + + def process_print( + self, *args, filter_: Filter = "all", rank_prefix=True, flush=True, **kwargs + ) -> None: + if not self.check_filter(filter_): + return + + if rank_prefix: + prefix = f"[global_rank={self.global_rank}]" + if len(args) == 1: # avoid out-of-order printing if possible + out = f"{prefix} {args[0]}" + args = (out,) + else: + args = (prefix, *args) + print(*args, flush=flush, **kwargs) + else: + print(*args, flush=flush, **kwargs) + + +class NativeDdpRuntime(IRuntime): + def __init__( + self, + dtype: torch.dtype = torch.float, + torch_distributed_timeout: timedelta | None = None, + ): + self._master_addr = os.environ["MASTER_ADDR"] + self._master_port = int(os.environ["MASTER_PORT"]) + self._local_rank = int(os.environ["LOCAL_RANK"]) + self._global_rank = int(os.environ["RANK"]) + self._local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + self._world_size = int(os.environ["WORLD_SIZE"]) + self._device = torch.device(self.local_rank) + self._dtype = dtype + self._torch_distributed_timeout = torch_distributed_timeout + + @override + def setup(self): + torch.cuda.set_device(self._device) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + "cpu:gloo,cuda:nccl", timeout=self._torch_distributed_timeout + ) + input_tensors = [ + torch.tensor([0], dtype=torch.float32, device=self._device) + for _ in range(self.world_size) + ] + output_tensors = [ + torch.tensor([0], dtype=torch.float32, device=self._device) + for _ in range(self.world_size) + ] + torch.distributed.all_to_all(input_tensors, output_tensors) + + @override + def cleanup(self): + with suppress(Exception): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + @override + def autocast(self) -> AbstractContextManager: + result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) + return result + + @override + def wait_for_everyone(self): + torch.distributed.barrier() + + @override + def set_seed(self, seed: int, device_specific: bool = False) -> int: + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): + The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. + """ + if device_specific: + seed += self.global_rank + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + return seed + + @override + def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: + assert all(isinstance(x, nn.Module) for x in models) + new_models = [nn.parallel.DistributedDataParallel(m) for m in models] + new_models = cast("PrepareModelsT", new_models) + return new_models # type: ignore[return-value] + + @override + def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return train_dataloader + + @override + def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return val_dataloader + + @override + def compile(self, model: CompileT) -> CompileT: + result = torch.compile(model) + result = cast("CompileT", result) + return result + + @override + def backward(self, loss: torch.Tensor) -> None: + loss.backward() + + @override + def clip_grad_norm_( + self, + parameters: Iterable[torch.Tensor] | torch.Tensor, + max_norm: float, + norm_type: float = 2, + ) -> torch.Tensor: + result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) + return result + + @override + def clip_grad_value_( + self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float + ) -> None: + torch.nn.utils.clip_grad_value_(parameters, clip_value) + + @override + def save_state(self, path: str | Path) -> None: + pass + + @override + def load_state(self, path: str | Path) -> None: + pass + + @override + def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: + for _ in tqdm( + range(num_batches), desc=f"rank {self._global_rank}: skip_first_batches({num_batches=})" + ): + next(dataloader_iterator) + + @property + @override + def sync_gradients(self) -> bool: + return True + + @property + @override + def is_main_process(self) -> bool: + result = self.global_rank == 0 + return result + + @property + @override + def is_local_main_process(self) -> bool: + result = self.local_rank == 0 + return result + + @property + @override + def is_last_process(self) -> bool: + result = self.global_rank == self.world_size - 1 + return result + + @property + @override + def is_local_last_process(self) -> bool: + result = self.local_rank == self.local_world_size - 1 + return result + + @property + @override + def local_rank(self) -> int: + return self._local_rank + + @property + @override + def global_rank(self) -> int: + return self._global_rank + + @property + @override + def local_world_size(self) -> int: + return self._local_world_size + + @property + @override + def world_size(self) -> int: + return self._world_size + + @property + @override + def device(self) -> torch.device: + return self._device + + @property + @override + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def master_addr(self) -> str: + return self._master_addr + + @property + def master_port(self) -> int: + return self._master_port + + +class BaseRuntime(IRuntime): + def __init__(self, dtype: torch.dtype = torch.float): + self._device = torch.device(self.local_rank) + self._dtype = dtype + + @override + def setup(self): + torch.cuda.set_device(self._device) + + @override + def cleanup(self): ... + + @override + def autocast(self) -> AbstractContextManager: + result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) + return result + + @override + def wait_for_everyone(self): ... + + @override + def set_seed(self, seed: int, device_specific: bool = False) -> int: + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): + The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. + """ + if device_specific: + seed += self.global_rank + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + return seed + + @override + def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: + assert all(isinstance(x, nn.Module) for x in models) + return models + + @override + def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return train_dataloader + + @override + def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return val_dataloader + + @override + def compile(self, model: CompileT) -> CompileT: + result = torch.compile(model) + result = cast("CompileT", result) + return result + + @override + def backward(self, loss: torch.Tensor) -> None: + loss.backward() + + @override + def clip_grad_norm_( + self, + parameters: Iterable[torch.Tensor] | torch.Tensor, + max_norm: float, + norm_type: float = 2, + ) -> torch.Tensor: + result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) + return result + + @override + def clip_grad_value_( + self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float + ) -> None: + torch.nn.utils.clip_grad_value_(parameters, clip_value) + + @override + def save_state(self, path: str | Path) -> None: + pass + + @override + def load_state(self, path: str | Path) -> None: + pass + + @override + def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: + for _ in tqdm( + range(num_batches), desc=f"rank {self.global_rank}: skip_first_batches({num_batches=})" + ): + next(dataloader_iterator) + + @property + @override + def sync_gradients(self) -> bool: + return True + + @property + @override + def is_main_process(self) -> bool: + result = self.global_rank == 0 + return result + + @property + @override + def is_local_main_process(self) -> bool: + result = self.local_rank == 0 + return result + + @property + @override + def is_last_process(self) -> bool: + result = self.global_rank == self.world_size - 1 + return result + + @property + @override + def is_local_last_process(self) -> bool: + result = self.local_rank == self.local_world_size - 1 + return result + + @property + @override + def local_rank(self) -> int: + return 0 + + @property + @override + def global_rank(self) -> int: + return 0 + + @property + @override + def local_world_size(self) -> int: + return 1 + + @property + @override + def world_size(self) -> int: + return 1 + + @property + @override + def device(self) -> torch.device: + return self._device + + @property + @override + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def master_addr(self) -> str | None: + return None + + @property + def master_port(self) -> int | None: + return None diff --git a/tests/_test_utils/torch_dist/dist_utils.py b/tests/_test_utils/torch_dist/dist_utils.py index c7407b0188..f7160cf288 100644 --- a/tests/_test_utils/torch_dist/dist_utils.py +++ b/tests/_test_utils/torch_dist/dist_utils.py @@ -34,6 +34,11 @@ def init_process(rank, size, job=None, backend="gloo", port=None): """Initialize the distributed environment.""" os.environ["MASTER_ADDR"] = "localhost" + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(size) + os.environ["LOCAL_WORLD_SIZE"] = str(size) + os.environ["WANDB_DISABLED"] = "true" port = str(get_free_port()) if port is None else str(port) diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml new file mode 100644 index 0000000000..1d8fac655f --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml @@ -0,0 +1,108 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml b/tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml b/tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..f0c852eec9 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml @@ -0,0 +1,12 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..0a5eafcfff --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml @@ -0,0 +1,32 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 100 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_outpt_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml new file mode 100644 index 0000000000..046ff51f65 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: conversation +val_dataset_name: train +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json b/tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..02ee80b619 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json @@ -0,0 +1,16 @@ +{ + "bos_token": { + "content": "<|begin_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "<|eot_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json new file mode 100644 index 0000000000..83592e2494 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json @@ -0,0 +1,212 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + }, + "behavior": "Isolated", + "invert": false + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false + } + ] + }, + "post_processor": { + "type": "Sequence", + "processors": [ + { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 1 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "<|begin_of_text|>": { + "id": "<|begin_of_text|>", + "ids": [ + 100 + ], + "tokens": [ + "<|begin_of_text|>" + ] + } + } + } + ] + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "!": 0, + "\"": 1, + "#": 2, + "$": 3, + "%": 4, + "&": 5, + "'": 6, + "(": 7, + ")": 8, + "*": 9, + "+": 10, + ",": 11, + "-": 12, + ".": 13, + "/": 14, + "0": 15, + "1": 16, + "2": 17, + "3": 18, + "4": 19, + "5": 20, + "6": 21, + "7": 22, + "8": 23, + "9": 24, + ":": 25, + ";": 26, + "<": 27, + "=": 28, + ">": 29, + "?": 30, + "@": 31, + "A": 32, + "B": 33, + "C": 34, + "D": 35, + "E": 36, + "F": 37, + "G": 38, + "H": 39, + "I": 40, + "J": 41, + "K": 42, + "L": 43, + "M": 44, + "N": 45, + "O": 46, + "P": 47, + "Q": 48, + "R": 49, + "S": 50, + "T": 51, + "U": 52, + "V": 53, + "W": 54, + "X": 55, + "Y": 56, + "Z": 57, + "[": 58, + "\\": 59, + "]": 60, + "^": 61, + "_": 62, + "`": 63, + "a": 64, + "b": 65, + "c": 66, + "d": 67, + "e": 68, + "f": 69, + "g": 70, + "h": 71, + "i": 72, + "j": 73, + "k": 74, + "l": 75, + "m": 76, + "n": 77, + "o": 78, + "p": 79, + "q": 80, + "r": 81, + "s": 82, + "t": 83, + "u": 84, + "v": 85, + "w": 86, + "x": 87, + "y": 88, + "z": 89, + "{": 90, + "|": 91, + "}": 92, + "~": 93, + "¡": 94, + "¢": 95, + "£": 96, + "¤": 97, + "¥": 98, + "¦": 99, + "<|begin_of_text|>": 100, + "<|eot_id|>": 101 + }, + "merges": [] + } +} diff --git a/tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..754d9e8db5 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json @@ -0,0 +1,13 @@ +{ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", + "clean_up_tokenization_spaces": true, + "eos_token": "<|eot_id|>", + "extra_special_tokens": {}, + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py b/tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py new file mode 100644 index 0000000000..aedcae4ab2 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script was used to truncate the tokenizer.json file from Llama 3.1 8B model +to keep only the top 100 most common tokens. +""" + +import json + +# Path to your original and new tokenizer.json +in_path = "./tokenizer.json" +out_path = "./tokenizer_truncated.json" + +# How many top tokens to keep +NUM_TO_KEEP = 100 + +with open(in_path, encoding="utf-8") as f: + tokenizer_data = json.load(f) + +# Get and sort the original vocab by index (frequency proxy) +orig_vocab = tokenizer_data["model"]["vocab"] + +# Sort tokens by their original index (lowest index = assumed most common/important) +sorted_tokens = sorted(orig_vocab.items(), key=lambda item: item[1]) + +# Keep the top N tokens +tokens_to_keep = [tok for tok, idx in sorted_tokens[:NUM_TO_KEEP]] + +# Re-index the selected tokens: 0..N-1 +small_vocab = {tok: i for i, tok in enumerate(tokens_to_keep)} +tokenizer_data["model"]["vocab"] = small_vocab + +# Update vocab size +if "vocab_size" in tokenizer_data["model"]: + tokenizer_data["model"]["vocab_size"] = len(small_vocab) + +# Optionally remove merges if present and unneeded (mostly for BPE/WordPiece) +if "merges" in tokenizer_data["model"]: + tokenizer_data["model"]["merges"] = [] + +# Remove added_tokens if not needed +if "added_tokens" in tokenizer_data: + tokenizer_data["added_tokens"] = [] + +# Write out the truncated tokenizer.json +with open(out_path, "w", encoding="utf-8") as f: + json.dump(tokenizer_data, f, indent=2, ensure_ascii=False) + +print(f"Truncated tokenizer saved to: {out_path}") diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py new file mode 100644 index 0000000000..096de4de3c --- /dev/null +++ b/tests/experimental/torch/_compress/test_compress.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import os +import shutil +from functools import partial +from pathlib import Path + +import pytest +import torch +from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job +from datasets import Dataset, DatasetDict +from puzzle_tools.hydra_utils import register_hydra_resolvers +from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase + +from modelopt.torch._compress import compress +from modelopt.torch._compress.runtime import NativeDdpRuntime + + +@pytest.fixture +def project_root_path(request: pytest.FixtureRequest) -> Path: + return Path(request.config.rootpath) + + +# The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) +# using a one-click command. +# +# Note: Bypass is disabled now in the test. + +# How to run this test (currently only supported internally at Nvidia). +# +# Have both modelopt and puzzle source code in the same directory: +# /workspace/modelopt +# /workspace/puzzletron +# +# submit_job --partition interactive --time 0 \ +# --image gitlab-master.nvidia.com/deci/puzzletron:trtllm_main \ +# --workdir $MODELOPT SRC DIRECTORY --interactive --gpu 1 +# +# pip install mip +# pip install lru-dict +# +# export PYTHONPATH=$PYTHONPATH:/workspace/puzzletron/v1 +# +# pytest -s -v ./tests/experimental/torch/_compress/test_compress.py::test_compress -o addopts="" + + +def test_compress(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_compress_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): + register_hydra_resolvers() + + puzzle_dir = tmp_path + dataset_path = puzzle_dir / "dummy_dataset" + hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" + + _runtime = NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) + + with _runtime as runtime: + # + # Test setup + # + if rank == 0: + # Setup puzzle_dir and dataset + _setup_puzzle_dir(puzzle_dir) + _save_dummy_dataset(dataset_path) + + # + # Step 1: Create and save a teacher model to compress + # This mimics the normal pipeline where we start with a Llama model + # + tokenizer_path = ( + project_root_path / "tests/experimental/torch/_compress/resources/tokenizer" + ) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + # Create a small Llama model (not DeciLM) to match the normal conversion pipeline + hf_ckpt_teacher_dir = "ckpts/teacher" + llama_checkpoint_path = puzzle_dir / hf_ckpt_teacher_dir + _create_and_save_small_llama_model( + llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer + ) + + # Use the full conversion pipeline (matches normal usage) + convert_llama3_to_decilm( + input_dir=llama_checkpoint_path, + output_dir=llama_checkpoint_path, + ) + runtime.wait_for_everyone() + + # Compress the model using a one-click approach + compress.compress( + str(hydra_config_dir), "Llama-3_1-8B", str(puzzle_dir), str(dataset_path), runtime + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step 1 + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + # assertions for the build_library_and_stats step 4 + + assert (puzzle_dir / "replacement_library.json").is_file() + assert (puzzle_dir / "subblock_stats.json").is_file() + + # assertions for the scoring step 5 + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + + assert solution_0_filepath.exists() + + # assertions for the mip_and_realize_models step 6 + solution_0_ckpt_config_path = ( + puzzle_dir + / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" + ) + + assert solution_0_ckpt_config_path.exists() + assert ( + puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json" + ).exists() + + runtime.wait_for_everyone() + + print("PYTEST SUMMARY: test_compress_model() test has finished successfully") + + +def _create_and_save_small_llama_model( + output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase +): + """ + Create and save a small Llama model for testing the conversion pipeline. + This mimics having a real Llama checkpoint that needs to be converted. + """ + os.makedirs(output_path, exist_ok=True) + + # Create a minimal Llama config (small for testing) + # Note: intermediate_size must be divisible by 256 per DeciLM config requirements + # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility + llama_config = LlamaConfig( + vocab_size=vocab_size, + hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations) + intermediate_size=512, # Must be divisible by 256 + num_hidden_layers=2, + num_attention_heads=32, # Matches original test + num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4) + max_position_embeddings=512, + rms_norm_eps=1e-5, + rope_theta=10000.0, + attention_bias=False, + hidden_act="silu", + tie_word_embeddings=False, + ) + + # Create and save the Llama model + model = LlamaForCausalLM(llama_config) + model.to(dtype=torch.bfloat16).save_pretrained(output_path) + + # Save tokenizer + tokenizer.save_pretrained(output_path) + + # Save config + llama_config.save_pretrained(output_path) + + +def _setup_puzzle_dir(puzzle_dir: str): + if Path(puzzle_dir).exists(): + shutil.rmtree(puzzle_dir) + Path(puzzle_dir).mkdir(parents=True, exist_ok=True) + + +def _save_dummy_dataset(dataset_path: str): + # dummy sample + sample = [ + {"role": "user", "content": "please cite Lorem Ipsum?"}, + { + "role": "assistant", + "content": ( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. " + "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, " + "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, " + "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, " + "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. " + "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, " + "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. " + "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, " + "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. " + "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, " + "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. " + "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. " + "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. " + "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. " + "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. " + "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. " + "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. " + "Donec mollis convallis massa quis iaculis." + ), + }, + ] + + # Prepare train and val splits with sample repeated, 2500 samples are for + # 128 samples with block-size 8192 and LLama3 tokenizer + data = [{"conversation": sample}] * 2500 + + # For train-val splits + data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) + data_dict.save_to_disk(dataset_path) From cef3655dfe3f6a788ed7aee51ee4f2773477b25b Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 29 Oct 2025 17:47:44 +0100 Subject: [PATCH 04/62] Add convert_llama3_config_to_decilm_config + unit test (#465) Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../converters/convert_llama3_to_decilm.py | 152 ++++++++++++++++++ setup.py | 2 + .../experimental/torch/_compress/conftest.py | 120 ++++++++++++++ ..._convert_llama3_config_to_decilm_config.py | 53 ++++++ .../torch/_compress/test_compress.py | 115 ++----------- 5 files changed, 340 insertions(+), 102 deletions(-) create mode 100644 modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py create mode 100644 tests/experimental/torch/_compress/conftest.py create mode 100644 tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py new file mode 100644 index 0000000000..4b65eeada5 --- /dev/null +++ b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Convert a Llama3 model to a DeciLM model.""" + +#!/usr/bin/env python3 +from pathlib import Path + +from fire import Fire +from puzzle_tools.checkpoint_utils import copy_tokenizer +from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code +from puzzle_tools.conversion_utils import convert_model_weights_to_decilm +from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from transformers import LlamaConfig + +""" +example: + +python -m scripts.hf.convert_llama3_to_decilm \ + --input_dir .../meta-llama/Meta-Llama-3.1-8B-Instruct \ + --output_dir .../meta-llama/Meta-Llama-3.1-8B-Instruct--deci-hf/ +""" + + +def convert_llama3_config_to_decilm_config(config: LlamaConfig) -> DeciLMConfig: + """Convert Llama3 config to DeciLM config format.""" + print("\n=== Converting Llama3 Config to DeciLM Config ===") + + # Get dtype from config - check both dtype and torch_dtype + # Prefer dtype if it's set (not None), otherwise fall back to torch_dtype + dtype = getattr(config, "dtype", None) + if dtype is None: + dtype = getattr(config, "torch_dtype", None) + + # Convert torch.dtype to string if needed (for JSON serialization) + if dtype is not None and hasattr(dtype, "__module__") and "torch" in dtype.__module__: + dtype = str(dtype).replace("torch.", "") + + # Track which global values will be removed (moved to per-layer configs) + print("\n📝 Converting global values to per-layer block_configs:") + print( + f" - intermediate_size: {config.intermediate_size} → block_configs[*].ffn.intermediate_size" + ) + print( + f" - num_key_value_heads: {config.num_key_value_heads} → block_configs[*].attention.n_heads_in_group (derived)" + ) + print(f" - hidden_act: {config.hidden_act} → block_configs[*].ffn.hidden_act") + print( + f" - sliding_window: {getattr(config, 'sliding_window', None)} → block_configs[*].attention.window_length" + ) + + # Create block configs for each layer + block_configs = [] + for i in range(config.num_hidden_layers): + # Configure attention + attention_config = { + "no_op": False, + "replace_with_linear": False, + "sparsify": None, + "n_heads_in_group": config.num_attention_heads // config.num_key_value_heads, + "window_length": None, # Llama3 doesn't use sliding window by default + "num_sink_tokens": None, # Llama3 doesn't use sink attention + "use_prefill_window_in_sink_attention": False, + "unshifted_sink": False, + "mamba": None, + "llama4": None, # No Llama4 specific attention for Llama3 + } + + # Configure FFN + ffn_config = { + "no_op": False, + "replace_with_linear": False, + "sparsify": None, + "intermediate_size": config.intermediate_size, + "gated": True, # Llama3 uses SwiGLU + "hidden_act": config.hidden_act, + "moe": None, # Llama3 doesn't use MoE + } + + block_configs.append({"attention": attention_config, "ffn": ffn_config}) + + # Create DeciLM config + decilm_config = DeciLMConfig( + block_configs=block_configs, + hidden_size=config.hidden_size, + max_position_embeddings=config.max_position_embeddings, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=config.num_hidden_layers, + tie_word_embeddings=config.tie_word_embeddings, + vocab_size=config.vocab_size, + rms_norm_eps=config.rms_norm_eps, + attention_bias=config.attention_bias, + o_proj_bias=config.attention_bias, # llama3 bias defined by attention_bias + rope_theta=config.rope_theta, + rope_scaling=config.rope_scaling, + position_embedding_type="rope", # Llama3 uses standard RoPE + architectures=["DeciLMForCausalLM"], + auto_map={ + "AutoConfig": "configuration_decilm.DeciLMConfig", + "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM", + }, + eos_token_id=config.eos_token_id, + bos_token_id=config.bos_token_id, + pad_token_id=config.pad_token_id, + head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), + dtype=dtype, + ) + + print(f"\n✓ Created DeciLM config with {len(block_configs)} layers") + print( + "✓ Global per-layer keys (intermediate_size, num_key_value_heads, hidden_act, sliding_window)" + ) + print(" will be removed from saved config and are only in block_configs") + + return decilm_config + + +def convert_configs_in_dirs(input_dir, output_dir): + """Convert the config of a Llama3 model to a DeciLM model.""" + input_dir = Path(input_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + input_config_path = input_dir / "config.json" + config = LlamaConfig.from_pretrained(input_config_path) + decilm_config = convert_llama3_config_to_decilm_config(config) + decilm_config.save_pretrained(output_dir) + + +def convert_llama3_to_decilm(input_dir, output_dir): + """Convert a Llama3 model to a DeciLM model.""" + convert_configs_in_dirs(input_dir, output_dir) + copy_tokenizer(input_dir, output_dir) + convert_model_weights_to_decilm(input_dir, output_dir) + copy_deci_lm_hf_code(output_dir) + + +if __name__ == "__main__": + Fire(convert_llama3_to_decilm) diff --git a/setup.py b/setup.py index 67bf114ae1..cfadd51705 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,8 @@ "setuptools>=80", "setuptools-scm>=8", ], + # Dependedencies for modelopt.torch._compress subpackage + "compress": ["fire"], } # create "compound" optional dependencies diff --git a/tests/experimental/torch/_compress/conftest.py b/tests/experimental/torch/_compress/conftest.py new file mode 100644 index 0000000000..4dedf5363b --- /dev/null +++ b/tests/experimental/torch/_compress/conftest.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +from pathlib import Path + +import pytest +import torch +from datasets import Dataset, DatasetDict +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase + + +@pytest.fixture +def project_root_path(request: pytest.FixtureRequest) -> Path: + """Fixture providing the project root path for tests.""" + return Path(request.config.rootpath) + + +def create_and_save_small_llama_model( + output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase +): + """ + Create and save a small Llama model for testing the conversion pipeline. + This mimics having a real Llama checkpoint that needs to be converted. + """ + os.makedirs(output_path, exist_ok=True) + + # Create a minimal Llama config (small for testing) + # Note: intermediate_size must be divisible by 256 per DeciLM config requirements + # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility + llama_config = LlamaConfig( + vocab_size=vocab_size, + hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations) + intermediate_size=512, # Must be divisible by 256 + num_hidden_layers=2, + num_attention_heads=32, # Matches original test + num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4) + max_position_embeddings=512, + rms_norm_eps=1e-5, + rope_theta=10000.0, + attention_bias=False, + hidden_act="silu", + tie_word_embeddings=False, + ) + + # Create and save the Llama model + model = LlamaForCausalLM(llama_config) + model.to(dtype=torch.bfloat16).save_pretrained(output_path) + + # Save tokenizer + tokenizer.save_pretrained(output_path) + + # Save config + llama_config.save_pretrained(output_path) + + +def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: + """ + Create a tokenizer for the Llama model. + """ + tokenizer_path = project_root_path / "tests/experimental/torch/_compress/resources/tokenizer" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + return tokenizer + + +def setup_puzzle_dir(puzzle_dir: str): + if Path(puzzle_dir).exists(): + shutil.rmtree(puzzle_dir) + Path(puzzle_dir).mkdir(parents=True, exist_ok=True) + + +def save_dummy_dataset(dataset_path: str): + # dummy sample + sample = [ + {"role": "user", "content": "please cite Lorem Ipsum?"}, + { + "role": "assistant", + "content": ( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. " + "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, " + "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, " + "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, " + "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. " + "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, " + "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. " + "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, " + "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. " + "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, " + "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. " + "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. " + "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. " + "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. " + "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. " + "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. " + "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. " + "Donec mollis convallis massa quis iaculis." + ), + }, + ] + + # Prepare train and val splits with sample repeated, 2500 samples are for + # 128 samples with block-size 8192 and LLama3 tokenizer + data = [{"conversation": sample}] * 2500 + + # For train-val splits + data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) + data_dict.save_to_disk(dataset_path) diff --git a/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py new file mode 100644 index 0000000000..a1d897ceb5 --- /dev/null +++ b/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path + +from experimental.torch._compress.conftest import ( + create_and_save_small_llama_model, + create_tokenizer, +) + +from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( + convert_llama3_to_decilm, +) + + +def test_convert_llama3_config_to_decilm_config(project_root_path: Path, tmp_path: Path): + tokenizer = create_tokenizer(project_root_path) + llama_checkpoint_path = tmp_path / "llama_checkpoint" + create_and_save_small_llama_model( + llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer + ) + + # Convert the Llama model to a DeciLM model + decilm_checkpoint_path = tmp_path / "decilm_checkpoint" + convert_llama3_to_decilm( + input_dir=llama_checkpoint_path, + output_dir=decilm_checkpoint_path, + ) + + # Assert that the converted config has the correct number of block_configs + config_path = decilm_checkpoint_path / "config.json" + assert config_path.exists(), f"Config file not found at {config_path}" + + with open(config_path) as f: + decilm_config = json.load(f) + + # Verify block_configs exists and has the correct length + assert "block_configs" in decilm_config, "block_configs not found in converted config" + actual_num_block_configs = len(decilm_config["block_configs"]) + assert actual_num_block_configs == 2 diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 096de4de3c..018b78e1a5 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -15,27 +15,23 @@ import datetime import os -import shutil from functools import partial from pathlib import Path -import pytest import torch -from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job -from datasets import Dataset, DatasetDict +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from experimental.torch._compress.conftest import ( + create_and_save_small_llama_model, + create_tokenizer, + save_dummy_dataset, + setup_puzzle_dir, +) from puzzle_tools.hydra_utils import register_hydra_resolvers from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase from modelopt.torch._compress import compress from modelopt.torch._compress.runtime import NativeDdpRuntime - -@pytest.fixture -def project_root_path(request: pytest.FixtureRequest) -> Path: - return Path(request.config.rootpath) - - # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) # using a one-click command. # @@ -73,6 +69,7 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran puzzle_dir = tmp_path dataset_path = puzzle_dir / "dummy_dataset" hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" + hydra_config_name = "Llama-3_1-8B" _runtime = NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) @@ -84,23 +81,19 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran # if rank == 0: # Setup puzzle_dir and dataset - _setup_puzzle_dir(puzzle_dir) - _save_dummy_dataset(dataset_path) + setup_puzzle_dir(puzzle_dir) + save_dummy_dataset(dataset_path) # # Step 1: Create and save a teacher model to compress # This mimics the normal pipeline where we start with a Llama model # - tokenizer_path = ( - project_root_path / "tests/experimental/torch/_compress/resources/tokenizer" - ) - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) # Create a small Llama model (not DeciLM) to match the normal conversion pipeline + tokenizer = create_tokenizer(project_root_path) hf_ckpt_teacher_dir = "ckpts/teacher" llama_checkpoint_path = puzzle_dir / hf_ckpt_teacher_dir - _create_and_save_small_llama_model( + create_and_save_small_llama_model( llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer ) @@ -113,7 +106,7 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran # Compress the model using a one-click approach compress.compress( - str(hydra_config_dir), "Llama-3_1-8B", str(puzzle_dir), str(dataset_path), runtime + str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path), runtime ) # @@ -156,85 +149,3 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran runtime.wait_for_everyone() print("PYTEST SUMMARY: test_compress_model() test has finished successfully") - - -def _create_and_save_small_llama_model( - output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase -): - """ - Create and save a small Llama model for testing the conversion pipeline. - This mimics having a real Llama checkpoint that needs to be converted. - """ - os.makedirs(output_path, exist_ok=True) - - # Create a minimal Llama config (small for testing) - # Note: intermediate_size must be divisible by 256 per DeciLM config requirements - # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility - llama_config = LlamaConfig( - vocab_size=vocab_size, - hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations) - intermediate_size=512, # Must be divisible by 256 - num_hidden_layers=2, - num_attention_heads=32, # Matches original test - num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4) - max_position_embeddings=512, - rms_norm_eps=1e-5, - rope_theta=10000.0, - attention_bias=False, - hidden_act="silu", - tie_word_embeddings=False, - ) - - # Create and save the Llama model - model = LlamaForCausalLM(llama_config) - model.to(dtype=torch.bfloat16).save_pretrained(output_path) - - # Save tokenizer - tokenizer.save_pretrained(output_path) - - # Save config - llama_config.save_pretrained(output_path) - - -def _setup_puzzle_dir(puzzle_dir: str): - if Path(puzzle_dir).exists(): - shutil.rmtree(puzzle_dir) - Path(puzzle_dir).mkdir(parents=True, exist_ok=True) - - -def _save_dummy_dataset(dataset_path: str): - # dummy sample - sample = [ - {"role": "user", "content": "please cite Lorem Ipsum?"}, - { - "role": "assistant", - "content": ( - "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. " - "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, " - "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, " - "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, " - "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. " - "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, " - "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. " - "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, " - "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. " - "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, " - "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. " - "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. " - "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. " - "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. " - "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. " - "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. " - "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. " - "Donec mollis convallis massa quis iaculis." - ), - }, - ] - - # Prepare train and val splits with sample repeated, 2500 samples are for - # 128 samples with block-size 8192 and LLama3 tokenizer - data = [{"conversation": sample}] * 2500 - - # For train-val splits - data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) - data_dict.save_to_disk(dataset_path) From 002b8b522cacef045d3ce714aac4c61b8af657a3 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 31 Oct 2025 19:27:33 +0100 Subject: [PATCH 05/62] Implement nas.convert() api for the compress algorithm (#482) Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/_compress/compress.py | 3 +- .../converters/convert_llama3_to_decilm.py | 3 +- modelopt/torch/_compress/hydra.py | 54 ++++++ .../nas/plugins/compress_nas_plugin.py | 167 ++++++++++++++++++ setup.py | 6 +- .../torch/_compress/compress_test_utils.py | 119 +++++++++++++ .../experimental/torch/_compress/conftest.py | 96 ---------- ..._convert_llama3_config_to_decilm_config.py | 2 +- .../_compress/nas/plugins/test_nas_convert.py | 114 ++++++++++++ .../torch/_compress/test_compress.py | 20 ++- 10 files changed, 474 insertions(+), 110 deletions(-) create mode 100644 modelopt/torch/_compress/hydra.py create mode 100644 modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py create mode 100644 tests/experimental/torch/_compress/compress_test_utils.py create mode 100644 tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 265fd5eeb2..455cf3f8ec 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -28,8 +28,7 @@ from omegaconf import DictConfig from puzzle_tools.runtime import IRuntime -# TODO Move initialize_hydra_config_for_dir from tests to main -from tests.utils.test_utils import initialize_hydra_config_for_dir +from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir def compress( diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py index 4b65eeada5..d17e7ef74b 100644 --- a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py +++ b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py @@ -19,6 +19,7 @@ #!/usr/bin/env python3 from pathlib import Path +import torch from fire import Fire from puzzle_tools.checkpoint_utils import copy_tokenizer from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code @@ -46,7 +47,7 @@ def convert_llama3_config_to_decilm_config(config: LlamaConfig) -> DeciLMConfig: dtype = getattr(config, "torch_dtype", None) # Convert torch.dtype to string if needed (for JSON serialization) - if dtype is not None and hasattr(dtype, "__module__") and "torch" in dtype.__module__: + if dtype is not None and isinstance(dtype, torch.dtype): dtype = str(dtype).replace("torch.", "") # Track which global values will be removed (moved to per-layer configs) diff --git a/modelopt/torch/_compress/hydra.py b/modelopt/torch/_compress/hydra.py new file mode 100644 index 0000000000..8c36d309e4 --- /dev/null +++ b/modelopt/torch/_compress/hydra.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra import compose, initialize, initialize_config_dir +from omegaconf import DictConfig, OmegaConf + +""" +Utilities for hydra config initialization. +""" + + +def initialize_hydra_config_for_dir( + config_dir: str, config_name: str, overrides: list[str] +) -> DictConfig: + """Initialize a hydra config from an absolute path for a config directory + + Args: + config_dir (str): + config_name (str): + overrides (List[str]): + + Returns: + DictConfig: + """ + + with initialize_config_dir(version_base=None, config_dir=config_dir): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args + + +def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig: + with initialize(version_base=None, config_path=config_path): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py new file mode 100644 index 0000000000..9c4b5acadd --- /dev/null +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compress NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). +""" + +import datetime +from pathlib import Path + +import pruning_ckpts +import score_pruning_activations +import torch +from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm +from torch import nn + +from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch.nas.conversion import NASModeRegistry +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ConvertReturnType, + MetadataDict, + ModeDescriptor, + RestoreEntrypoint, +) +from modelopt.torch.opt.searcher import BaseSearcher + + +class CompressModel(nn.Module): + pass # No model implementation is needed for the compress mode + + +class CompressConfig(ModeloptBaseConfig): + """Configuration for Compress NAS algorithm.""" + + # Input model path to compress in the HF format + input_model_path: str = ModeloptField( + default="", + title="", + description="", + ) + + # Hydra config directory containing the search space definition + hydra_config_dir: str = ModeloptField( + default="", + title="", + description="", + ) + + # Hydra config name containing the search space definition + hydra_config_name: str = ModeloptField( + default="", + title="", + description="", + ) + + # Directory to save the compressed model and intermediate results + puzzle_dir: str = ModeloptField( + default="", + title="", + description="", + ) + + # Dataset path to use for scoring in prunining and NAS search + dataset_path: str = ModeloptField( + default="", + title="", + description="", + ) + + +def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertReturnType: + """1. Convert the model from HF format to DeciLM format. + 2. Score the pruning activations. + 3. Prune the model and save pruned checkpoints + + The output of this step will be used by mnt.search() to perform the NAS search. + """ + runtime = NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=config.hydra_config_dir, + config_name=config.hydra_config_name, + overrides=[ + f"puzzle_dir={config.puzzle_dir}", + f"dataset_path={config.dataset_path}", + ], + ) + + # Convert Llama3 model to DeciLM model + hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + convert_llama3_to_decilm( + input_dir=config.input_model_path, + output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, + ) + + # Score_pruning_activations (distributed processing) + score_pruning_activations.launch_score_activations(hydra_cfg, runtime) + + # Prune the model and save pruned checkpoints + if runtime.global_rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + runtime.wait_for_everyone() + + return model, {} + + +def restore_compress_model( + model: nn.Module, config: CompressConfig, metadata: MetadataDict +) -> nn.Module: + """Restore is not needed for the compress mode as we are not saving any model state""" + return model + + +@NASModeRegistry.register_mode +class CompressDescriptor(ModeDescriptor): + """Descriptor for the Compress mode.""" + + @property + def name(self) -> str: + """String identifier for this mode.""" + return "compress" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Configuration class for this mode.""" + return CompressConfig + + @property + def search_algorithm(self) -> type[BaseSearcher]: + """Return the associated searcher implementation.""" + raise NotImplementedError("Compress mode does not have a search algorithm yet.") + + @property + def convert(self) -> ConvertEntrypoint: + """Entrypoint to convert a model.""" + return convert_compress_model + + @property + def restore(self) -> RestoreEntrypoint: + """Entrypoint to restore a model.""" + return restore_compress_model + + @property + def export_mode(self) -> str | None: + """The mode that corresponds to the export mode. + For now, this will be a no-op as there is no modelopt's concept of search space defined + for the compress algorithm. + """ + return "export_nas" diff --git a/setup.py b/setup.py index cfadd51705..568131f486 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,11 @@ "setuptools-scm>=8", ], # Dependedencies for modelopt.torch._compress subpackage - "compress": ["fire"], + "compress": [ + "fire", + "hydra-core==1.3.2", + "omegaconf==2.3.0", + ], } # create "compound" optional dependencies diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/experimental/torch/_compress/compress_test_utils.py new file mode 100644 index 0000000000..21ca622dae --- /dev/null +++ b/tests/experimental/torch/_compress/compress_test_utils.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +from pathlib import Path + +import torch +from datasets import Dataset, DatasetDict +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase + + +def create_and_save_small_llama_model( + output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase +): + """ + Create and save a small Llama model for testing the conversion pipeline. + This mimics having a real Llama checkpoint that needs to be converted. + """ + os.makedirs(output_path, exist_ok=True) + + # Create a minimal Llama config (small for testing) + # Note: intermediate_size must be divisible by 256 per DeciLM config requirements + # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility + llama_config = LlamaConfig( + vocab_size=vocab_size, + hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations) + intermediate_size=512, # Must be divisible by 256 + num_hidden_layers=2, + num_attention_heads=32, # Matches original test + num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4) + max_position_embeddings=512, + rms_norm_eps=1e-5, + rope_theta=10000.0, + attention_bias=False, + hidden_act="silu", + tie_word_embeddings=False, + ) + + # Create and save the Llama model + model = LlamaForCausalLM(llama_config) + model.to(dtype=torch.bfloat16).save_pretrained(output_path) + + # Save tokenizer + tokenizer.save_pretrained(output_path) + + # Save config + llama_config.save_pretrained(output_path) + + +def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: + """ + Create a tokenizer for the Llama model. + """ + tokenizer_path = project_root_path / "tests/experimental/torch/_compress/resources/tokenizer" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + return tokenizer + + +def setup_puzzle_dir(puzzle_dir: str): + """ + Setup puzzle directory by removing existing directory and creating a new one. + """ + if Path(puzzle_dir).exists(): + shutil.rmtree(puzzle_dir) + Path(puzzle_dir).mkdir(parents=True, exist_ok=True) + + +def save_dummy_dataset(dataset_path: str): + """ + Save a dummy dataset for testing purposes. + """ + # dummy sample + sample = [ + {"role": "user", "content": "please cite Lorem Ipsum?"}, + { + "role": "assistant", + "content": ( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. " + "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, " + "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, " + "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, " + "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. " + "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, " + "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. " + "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, " + "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. " + "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, " + "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. " + "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. " + "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. " + "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. " + "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. " + "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. " + "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. " + "Donec mollis convallis massa quis iaculis." + ), + }, + ] + + # Prepare train and val splits with sample repeated, 2500 samples are for + # 128 samples with block-size 8192 and LLama3 tokenizer + data = [{"conversation": sample}] * 2500 + + # For train-val splits + data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) + data_dict.save_to_disk(dataset_path) diff --git a/tests/experimental/torch/_compress/conftest.py b/tests/experimental/torch/_compress/conftest.py index 4dedf5363b..cae1bfbca5 100644 --- a/tests/experimental/torch/_compress/conftest.py +++ b/tests/experimental/torch/_compress/conftest.py @@ -13,108 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import shutil from pathlib import Path import pytest -import torch -from datasets import Dataset, DatasetDict -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase @pytest.fixture def project_root_path(request: pytest.FixtureRequest) -> Path: """Fixture providing the project root path for tests.""" return Path(request.config.rootpath) - - -def create_and_save_small_llama_model( - output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase -): - """ - Create and save a small Llama model for testing the conversion pipeline. - This mimics having a real Llama checkpoint that needs to be converted. - """ - os.makedirs(output_path, exist_ok=True) - - # Create a minimal Llama config (small for testing) - # Note: intermediate_size must be divisible by 256 per DeciLM config requirements - # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility - llama_config = LlamaConfig( - vocab_size=vocab_size, - hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations) - intermediate_size=512, # Must be divisible by 256 - num_hidden_layers=2, - num_attention_heads=32, # Matches original test - num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4) - max_position_embeddings=512, - rms_norm_eps=1e-5, - rope_theta=10000.0, - attention_bias=False, - hidden_act="silu", - tie_word_embeddings=False, - ) - - # Create and save the Llama model - model = LlamaForCausalLM(llama_config) - model.to(dtype=torch.bfloat16).save_pretrained(output_path) - - # Save tokenizer - tokenizer.save_pretrained(output_path) - - # Save config - llama_config.save_pretrained(output_path) - - -def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: - """ - Create a tokenizer for the Llama model. - """ - tokenizer_path = project_root_path / "tests/experimental/torch/_compress/resources/tokenizer" - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - return tokenizer - - -def setup_puzzle_dir(puzzle_dir: str): - if Path(puzzle_dir).exists(): - shutil.rmtree(puzzle_dir) - Path(puzzle_dir).mkdir(parents=True, exist_ok=True) - - -def save_dummy_dataset(dataset_path: str): - # dummy sample - sample = [ - {"role": "user", "content": "please cite Lorem Ipsum?"}, - { - "role": "assistant", - "content": ( - "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. " - "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, " - "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, " - "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, " - "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. " - "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, " - "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. " - "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, " - "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. " - "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, " - "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. " - "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. " - "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. " - "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. " - "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. " - "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. " - "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. " - "Donec mollis convallis massa quis iaculis." - ), - }, - ] - - # Prepare train and val splits with sample repeated, 2500 samples are for - # 128 samples with block-size 8192 and LLama3 tokenizer - data = [{"conversation": sample}] * 2500 - - # For train-val splits - data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) - data_dict.save_to_disk(dataset_path) diff --git a/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py index a1d897ceb5..1f0283b3e8 100644 --- a/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py +++ b/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py @@ -16,7 +16,7 @@ import json from pathlib import Path -from experimental.torch._compress.conftest import ( +from experimental.torch._compress.compress_test_utils import ( create_and_save_small_llama_model, create_tokenizer, ) diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py new file mode 100644 index 0000000000..ad85804678 --- /dev/null +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import os +from functools import partial +from pathlib import Path + +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from experimental.torch._compress.compress_test_utils import ( + create_and_save_small_llama_model, + create_tokenizer, + save_dummy_dataset, + setup_puzzle_dir, +) +from puzzle_tools.hydra_utils import register_hydra_resolvers + +import modelopt.torch.nas as mtn +from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel +from modelopt.torch._compress.runtime import NativeDdpRuntime + + +# +# See tests/experimental/torch/_compress/test_compress.py for instructions on how to run this test +# TODO: Remove those instructions once this test runs automatically on CI +# +def test_nas_convert(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_convert_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_convert_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + # + # The inputs for the nas.convert() step. + # + puzzle_dir = tmp_path + llama_checkpoint_path = puzzle_dir / "ckpts/llama" + dataset_path = puzzle_dir / "dummy_dataset" + hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" + hydra_config_name = "Llama-3_1-8B" + + with NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) as runtime: + if rank == 0: + # Setup puzzle_dir and dataset + setup_puzzle_dir(puzzle_dir) + save_dummy_dataset(dataset_path) + + # Create a small Llama model + tokenizer = create_tokenizer(project_root_path) + create_and_save_small_llama_model( + llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer + ) + runtime.wait_for_everyone() + + # + # Run the mnt.convert() step + # + input_model = CompressModel() + mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + runtime.wait_for_everyone() + + print("PYTEST SUMMARY: test_nas_convert() test has finished successfully") diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 018b78e1a5..0bd116d161 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -20,7 +20,7 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from experimental.torch._compress.conftest import ( +from experimental.torch._compress.compress_test_utils import ( create_and_save_small_llama_model, create_tokenizer, save_dummy_dataset, @@ -66,16 +66,17 @@ def test_compress(project_root_path: Path, tmp_path: Path): def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): register_hydra_resolvers() + # + # The inputs for the compress() algorihm. + # puzzle_dir = tmp_path dataset_path = puzzle_dir / "dummy_dataset" hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" hydra_config_name = "Llama-3_1-8B" - _runtime = NativeDdpRuntime( + with NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) - - with _runtime as runtime: + ) as runtime: # # Test setup # @@ -91,8 +92,9 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran # Create a small Llama model (not DeciLM) to match the normal conversion pipeline tokenizer = create_tokenizer(project_root_path) - hf_ckpt_teacher_dir = "ckpts/teacher" - llama_checkpoint_path = puzzle_dir / hf_ckpt_teacher_dir + # TODO: change it to "ckpts/llama" once the conversion script is fixed + # Currently, the build replacement library step will fail with such a path. + llama_checkpoint_path = puzzle_dir / "ckpts/teacher" create_and_save_small_llama_model( llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer ) @@ -100,7 +102,7 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran # Use the full conversion pipeline (matches normal usage) convert_llama3_to_decilm( input_dir=llama_checkpoint_path, - output_dir=llama_checkpoint_path, + output_dir=puzzle_dir / "ckpts/teacher", ) runtime.wait_for_everyone() @@ -148,4 +150,4 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran runtime.wait_for_everyone() - print("PYTEST SUMMARY: test_compress_model() test has finished successfully") + print("PYTEST SUMMARY: test_compress_model() test has finished successfully") From 1c12fd8008d69a8e5ade1ca91b1e22f42232640d Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 3 Nov 2025 14:37:40 +0100 Subject: [PATCH 06/62] modelopt nas search() implementation for the compress algorithm (#490) Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../nas/plugins/compress_nas_plugin.py | 58 ++++++++- .../torch/_compress/compress_test_utils.py | 59 ++++++++++ .../_compress/nas/plugins/test_nas_convert.py | 35 +----- .../_compress/nas/plugins/test_nas_search.py | 110 ++++++++++++++++++ .../torch/_compress/test_compress.py | 50 ++------ 5 files changed, 239 insertions(+), 73 deletions(-) create mode 100644 tests/experimental/torch/_compress/nas/plugins/test_nas_search.py diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 9c4b5acadd..d821fbd029 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -20,12 +20,17 @@ import datetime from pathlib import Path +import build_library_and_stats +import mip_and_realize_models import pruning_ckpts import score_pruning_activations +import scoring import torch -from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm from torch import nn +from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( + convert_llama3_to_decilm, +) from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir from modelopt.torch._compress.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry @@ -37,7 +42,7 @@ ModeDescriptor, RestoreEntrypoint, ) -from modelopt.torch.opt.searcher import BaseSearcher +from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict class CompressModel(nn.Module): @@ -90,10 +95,19 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR The output of this step will be used by mnt.search() to perform the NAS search. """ + + # NativeDdpRuntime must be initialized/closed from outside of this function, so we are + # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it. runtime = NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) + # Required for mtn.search() to read NAS configuration + model.hydra_config_dir = config.hydra_config_dir + model.hydra_config_name = config.hydra_config_name + model.puzzle_dir = config.puzzle_dir + model.dataset_path = config.dataset_path + # Load hydra config hydra_cfg = initialize_hydra_config_for_dir( config_dir=config.hydra_config_dir, @@ -146,7 +160,8 @@ def config_class(self) -> type[ModeloptBaseConfig]: @property def search_algorithm(self) -> type[BaseSearcher]: """Return the associated searcher implementation.""" - raise NotImplementedError("Compress mode does not have a search algorithm yet.") + + return CompressSearcher @property def convert(self) -> ConvertEntrypoint: @@ -165,3 +180,40 @@ def export_mode(self) -> str | None: for the compress algorithm. """ return "export_nas" + + +class CompressSearcher(BaseSearcher): + """Runs NAS search for the Compress mode.""" + + @property + def default_state_dict(self) -> SearchStateDict: + """Not needed for the compress mode as we are not saving any model state""" + return {} + + def run_search(self) -> None: + # NativeDdpRuntime must be initialized/closed from outside of this function, so we are + # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it. + runtime = NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=self.model.hydra_config_dir, + config_name=self.model.hydra_config_name, + overrides=[ + f"puzzle_dir={self.model.puzzle_dir}", + f"dataset_path={self.model.dataset_path}", + ], + ) + + # Build_library_and_stats (single process) + if runtime.global_rank == 0: + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + runtime.wait_for_everyone() + + # Calc_one_block_scores (distributed processing) + scoring.launch_scoring(hydra_cfg, runtime) + + # mip_and_realize_models (distributed processing) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/experimental/torch/_compress/compress_test_utils.py index 21ca622dae..f0704f6c89 100644 --- a/tests/experimental/torch/_compress/compress_test_utils.py +++ b/tests/experimental/torch/_compress/compress_test_utils.py @@ -19,9 +19,68 @@ import torch from datasets import Dataset, DatasetDict +from puzzle_tools.hydra_utils import register_hydra_resolvers from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase +def setup_test_model_and_data( + project_root_path: Path, + tmp_path: Path, + rank: int, + runtime, +) -> tuple[ + Path, + Path, + Path, + Path, + str, +]: + """ + Setup the test model and data for the compress NAS search. + + Args: + project_root_path (Path): the root path of the project + tmp_path (Path): the temporary path to use for the test + rank (int): the rank of the process + runtime: the runtime to use for the test + + Returns: + tuple[Path, Path, Path, Path, str]: + the puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name + """ + + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + # The inputs for the nas.convert() step. + # + puzzle_dir = tmp_path + llama_checkpoint_path = puzzle_dir / "input_model/llama" + dataset_path = puzzle_dir / "dummy_dataset" + hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" + hydra_config_name = "Llama-3_1-8B" + + if rank == 0: + # Setup puzzle_dir and dataset + setup_puzzle_dir(puzzle_dir) + save_dummy_dataset(dataset_path) + + # Create a small Llama model + tokenizer = create_tokenizer(project_root_path) + create_and_save_small_llama_model( + llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer + ) + runtime.wait_for_everyone() + + return ( + puzzle_dir, + llama_checkpoint_path, + dataset_path, + hydra_config_dir, + hydra_config_name, + ) + + def create_and_save_small_llama_model( output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase ): diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py index ad85804678..7dc2d72285 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py @@ -20,13 +20,7 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from experimental.torch._compress.compress_test_utils import ( - create_and_save_small_llama_model, - create_tokenizer, - save_dummy_dataset, - setup_puzzle_dir, -) -from puzzle_tools.hydra_utils import register_hydra_resolvers +from experimental.torch._compress.compress_test_utils import setup_test_model_and_data import modelopt.torch.nas as mtn from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel @@ -48,32 +42,13 @@ def test_nas_convert(project_root_path: Path, tmp_path: Path): def _test_nas_convert_multiprocess_job( project_root_path: Path, tmp_path: Path, rank: int, size: int ): - # Register Hydra custom resolvers (needed for config resolution) - register_hydra_resolvers() - - # - # The inputs for the nas.convert() step. - # - puzzle_dir = tmp_path - llama_checkpoint_path = puzzle_dir / "ckpts/llama" - dataset_path = puzzle_dir / "dummy_dataset" - hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B" - with NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: - if rank == 0: - # Setup puzzle_dir and dataset - setup_puzzle_dir(puzzle_dir) - save_dummy_dataset(dataset_path) - - # Create a small Llama model - tokenizer = create_tokenizer(project_root_path) - create_and_save_small_llama_model( - llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer - ) - runtime.wait_for_everyone() + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( + setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + ) # # Run the mnt.convert() step diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py new file mode 100644 index 0000000000..04707d20f0 --- /dev/null +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# See tests/experimental/torch/_compress/test_compress.py for instructions on how to run this test +# TODO: Remove those instructions once this test runs automatically on CI +# +import datetime +from functools import partial +from pathlib import Path + +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from experimental.torch._compress.compress_test_utils import setup_test_model_and_data + +import modelopt.torch.nas as mtn +from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel +from modelopt.torch._compress.runtime import NativeDdpRuntime + + +def test_nas_search(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_search_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_search_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + with NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) as runtime: + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( + setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + ) + + # + # Run the mnt.convert() step + # + input_model = CompressModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Run the mnt.search() step + # + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + # + # Check assertions for mtn.search() step + # + if rank == 0: + # assertions for the build_library_and_stats step + assert (puzzle_dir / "replacement_library.json").is_file() + assert (puzzle_dir / "subblock_stats.json").is_file() + + # assertions for the scoring step + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + + assert solution_0_filepath.exists() + + # assertions for the mip_and_realize_models step + solution_0_ckpt_config_path = ( + puzzle_dir + / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" + ) + + assert solution_0_ckpt_config_path.exists() + assert ( + puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json" + ).exists() + + runtime.wait_for_everyone() + + print("PYTEST SUMMARY: test_nas_search() test has finished successfully") diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 0bd116d161..3d5d6b666d 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -20,16 +20,12 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from experimental.torch._compress.compress_test_utils import ( - create_and_save_small_llama_model, - create_tokenizer, - save_dummy_dataset, - setup_puzzle_dir, -) -from puzzle_tools.hydra_utils import register_hydra_resolvers -from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm +from experimental.torch._compress.compress_test_utils import setup_test_model_and_data from modelopt.torch._compress import compress +from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( + convert_llama3_to_decilm, +) from modelopt.torch._compress.runtime import NativeDdpRuntime # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) @@ -64,42 +60,16 @@ def test_compress(project_root_path: Path, tmp_path: Path): def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): - register_hydra_resolvers() - - # - # The inputs for the compress() algorihm. - # - puzzle_dir = tmp_path - dataset_path = puzzle_dir / "dummy_dataset" - hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B" - with NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: - # - # Test setup - # - if rank == 0: - # Setup puzzle_dir and dataset - setup_puzzle_dir(puzzle_dir) - save_dummy_dataset(dataset_path) - - # - # Step 1: Create and save a teacher model to compress - # This mimics the normal pipeline where we start with a Llama model - # - - # Create a small Llama model (not DeciLM) to match the normal conversion pipeline - tokenizer = create_tokenizer(project_root_path) - # TODO: change it to "ckpts/llama" once the conversion script is fixed - # Currently, the build replacement library step will fail with such a path. - llama_checkpoint_path = puzzle_dir / "ckpts/teacher" - create_and_save_small_llama_model( - llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer - ) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( + setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + ) - # Use the full conversion pipeline (matches normal usage) + # Convert the Llama model to DeciLM model. + if rank == 0: convert_llama3_to_decilm( input_dir=llama_checkpoint_path, output_dir=puzzle_dir / "ckpts/teacher", From f7d547fba005a6de3ff8015236166bad63ada88f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 12 Nov 2025 09:23:02 +0100 Subject: [PATCH 07/62] Add decilm modelling code (#505) ## What does this PR do? Add decilm modelling code --------- Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 12 + .../converters/convert_llama3_to_decilm.py | 3 +- .../decilm/deci_lm_hf_code/__init__.py | 15 + .../decilm/deci_lm_hf_code/block_config.py | 308 ++ .../deci_lm_hf_code/configuration_decilm.py | 210 ++ .../megatron_lm__mamba_mixer.py | 527 ++++ .../megatron_lm__megatron_tokenizer.py | 148 + .../deci_lm_hf_code/megatron_lm__tokenizer.py | 187 ++ .../decilm/deci_lm_hf_code/modeling_decilm.py | 2627 +++++++++++++++++ .../deci_lm_hf_code/tokenization_decilm.py | 195 ++ .../deci_lm_hf_code/tokenization_mistral.py | 374 +++ .../transformers_4_44_2__activations.py | 254 ++ .../transformers_4_44_2__cache_utils.py | 1447 +++++++++ ...ransformers_4_44_2__configuration_llama.py | 219 ++ ...ormers_4_44_2__modeling_attn_mask_utils.py | 498 ++++ ...g_flash_attention_utils_backward_compat.py | 363 +++ .../transformers_4_44_2__modeling_outputs.py | 1768 +++++++++++ ...ransformers_4_44_2__modeling_rope_utils.py | 574 ++++ .../transformers_4_44_2__pytorch_utils.py | 32 + .../transformers_4_51_3__cache_utils.py | 2535 ++++++++++++++++ ...ansformers_4_51_3__configuration_llama4.py | 447 +++ ...rmers_4_51_3__modeling_llama4_attention.py | 289 ++ .../decilm/deci_lm_hf_code/variable_cache.py | 213 ++ .../decilm/deci_lm_hf_code/vllm_yarn_utils.py | 210 ++ 24 files changed, 13454 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 598957f860..eec84b2b8a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,18 @@ repos: hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] + # See: commit hooks modifies block_config.py leading to test_compress.py failing (#25) · Issues · omniml / modelopt · GitLab + exclude: > + (?x)^( + modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py + )$ - id: ruff-format + exclude: > + (?x)^( + modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py + )$ - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.17.1 @@ -96,6 +107,7 @@ repos: examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| examples/speculative_decoding/server_generate.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py| )$ # Default hook for Apache 2.0 in c/c++/cuda files diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py index d17e7ef74b..96b96f3510 100644 --- a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py +++ b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py @@ -24,9 +24,10 @@ from puzzle_tools.checkpoint_utils import copy_tokenizer from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code from puzzle_tools.conversion_utils import convert_model_weights_to_decilm -from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig from transformers import LlamaConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig + """ example: diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py new file mode 100644 index 0000000000..d5eebfa352 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +import dataclasses +import inspect +import warnings +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, Type, Union, get_args, get_origin + + +@dataclass(frozen=True, kw_only=True) +class BaseDataclass: + """ + A dataclass base class with several utilities: + 1. Comparison via string representation. + 2. Initialization of dataclasses fields from dicts. + 3. Setting attributes even though it's frozen (but only inside __post_init__!) + """ + + def __eq__(self, other: "BaseDataclass") -> bool: + return str(self) == str(other) + + def __hash__(self) -> int: + return hash(str(self)) + + def __lt__(self, other: "BaseDataclass") -> bool: + return str(self) < str(other) + + def _force_setattr(self, name: str, value: Any) -> None: + """ + Set an attribute even in frozen dataclasses. + Use only inside __post_init__! + """ + assert _is_called_from_post_init(), ( + "_force_setattr should only be called from __post_init__, " + "if you need to change an attribute use dataclasses.replace " + "or create a new instance :)" + ) + object.__setattr__(self, name, value) + + def __post_init__(self): + """ + Init dataclass fields from dicts + """ + for field in dataclasses.fields(self): + field_dict = getattr(self, field.name) + if isinstance(field_dict, dict) and _is_dataclass_type(field.type): + dataclass_cls = _get_dataclass_type(field.type) + sub_fields = [field.name for field in dataclasses.fields(dataclass_cls)] + unsupported_fields = [ + field_name for field_name in field_dict.keys() if field_name not in sub_fields + ] + if len(unsupported_fields) > 0: + warnings.warn( + f"Removed unsupported fields {unsupported_fields} from {dataclass_cls}" + ) + + field_dict = {k: v for k, v in field_dict.items() if k not in unsupported_fields} + self._force_setattr(field.name, dataclass_cls(**field_dict)) + + +def _is_called_from_post_init() -> bool: + frame = inspect.currentframe() + while frame: + if frame.f_code.co_name == "__post_init__": + return True + frame = frame.f_back + return False + + +def _is_dataclass_type(tp: Type) -> bool: + """ + Like dataclasses.is_dataclass but also works for Optional[] and Union[] of a dataclass type + """ + try: + _get_dataclass_type(tp) + return True + except: + return False + + +def _get_dataclass_type(tp: Type) -> dataclass: + """ + If the given type is a dataclass, the function returns it. + If it is a Union[] or Optional[], the function extracts the first dataclass type. + If no dataclass type is found, the function raises a ValueError. + """ + origin = get_origin(tp) + if origin is Union: + for type_in_union in get_args(tp): + if dataclasses.is_dataclass(type_in_union): + return type_in_union + if dataclasses.is_dataclass(tp): + return tp + raise ValueError("Not a dataclass") + + +@dataclass(frozen=True, kw_only=True) +class SubblockConfig(BaseDataclass): + no_op: bool = False + replace_with_linear: bool = False + sparsify: Optional[list[str]] = None + weights_precision: Optional[str] = "bf16" + + def __post_init__(self): + super().__post_init__() + assert not (self.no_op and self.replace_with_linear) + if self.no_op: + self._force_setattr("sparsify", None) + + @abstractmethod + def to_blockconfig(self) -> "BlockConfig": + """ " + Convert to a block including this subblock only. + """ + ... + + +@dataclass(frozen=True, kw_only=True) +class MoEConfig(BaseDataclass): + """ + Configuration class for Mixture of Experts parameters. + """ + + num_local_experts: int = 8 + num_experts_per_tok: int = 1 + expert_intermediate_dim: int = 8192 + shared_expert_intermediate_dim: int = 8192 + # router_aux_loss_coef: float = 0.01 + # router_z_loss_coef: float = 0.0 # Optional z-loss coefficient + + def __post_init__(self): + # Validate the configuration + if self.num_local_experts <= 0: + raise ValueError(f"num_local_experts must be positive, got {self.num_local_experts}") + if self.num_experts_per_tok <= 0: + raise ValueError(f"top_k must be positive, got {self.top_k}") + if self.num_experts_per_tok > self.num_local_experts: + raise ValueError( + f"top_k ({self.top_k}) cannot be greater than num_local_experts ({self.num_local_experts})" + ) + # if self.router_aux_loss_coef < 0: + # raise ValueError(f"router_aux_loss_coef must be non-negative, got {self.router_aux_loss_coef}") + + +@dataclass(frozen=True, kw_only=True) +class MambaConfig(BaseDataclass): + state_dim: int + num_heads: int + head_dim: int + num_groups: int + + +@dataclass(frozen=True, kw_only=True) +class Llama4AttentionConfig(BaseDataclass): + attention_chunk_size: Optional[int] = None + use_rope: Optional[bool] = None + use_qk_norm: Optional[bool] = None + attn_scale: Optional[float] = None + floor_scale: Optional[float] = None + attn_temperature_tuning: Optional[bool] = None + attention_dropout: Optional[float] = None + + +@dataclass(frozen=True, kw_only=True) +class AttentionConfig(SubblockConfig): + n_heads_in_group: Optional[int] = None + window_length: Optional[int] = None + num_sink_tokens: Optional[int] = None + use_prefill_window_in_sink_attention: bool = False + unshifted_sink: bool = False + mamba: Optional[MambaConfig] = None + llama4: Optional[Llama4AttentionConfig] = None + + def __post_init__(self): + super().__post_init__() + + if self.no_op: + assert not self.replace_with_linear + assert not self.is_mamba + assert not self.is_llama4 + + if self.no_op or self.replace_with_linear or self.is_mamba: + for irrelevant_att in [ + "n_heads_in_group", + "window_length", + "num_sink_tokens", + "use_prefill_window_in_sink_attention", + "unshifted_sink", + "attention_chunk_size", + "attn_scale", + "floor_scale", + "attn_temperature_tuning", + "attention_dropout", + "use_qk_norm", + ]: + self._force_setattr(irrelevant_att, None) + else: + assert self.n_heads_in_group is not None + + if self.is_sink: + assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), ( + "Unshifted sink uses its own kind of explicit masking, not standard window. " + "Set use_prefill_window_in_sink_attention to False." + ) + assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), ( + "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True" + ) + + if self.is_llama4: + assert not self.is_sink, "Sink not support with Llama4 currently" + assert not self.is_sliding, "Sliding window not support with Llama4 currently" + assert not self.unshifted_sink, "Unshifted sink not support with Llama4 currently" + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=self, ffn=FFNConfig(no_op=True)) + + @property + def prefill_sliding_window(self) -> Optional[int]: + if self.window_length is not None: + if not self.is_sink or self.use_prefill_window_in_sink_attention: + return self.window_length + return None + + @property + def is_sliding(self) -> bool: + return self.prefill_sliding_window is not None + + @property + def is_sink(self) -> bool: + return (self.window_length is not None) and (self.num_sink_tokens is not None) + + @property + def is_mamba(self) -> bool: + return self.mamba is not None + + @property + def is_llama4(self) -> bool: + return self.llama4 is not None + + +@dataclass(frozen=True, kw_only=True) +class FFNConfig(SubblockConfig): + gated: Optional[bool] = ( + True # Gated Linear Unit e.g. SwiGLU or vanilla MLP (up -> activation -> down) + ) + hidden_act: Optional[str] = "silu" + moe: Optional[MoEConfig] = None + intermediate_size: Optional[int] = None + + def __post_init__(self): + super().__post_init__() + if self.no_op or self.replace_with_linear: + self._force_setattr("gated", None) + self._force_setattr("hidden_act", None) + self._force_setattr("moe", None) + self._force_setattr("intermediate_size", None) + elif self.is_moe: + self._force_setattr("gated", None) + self._force_setattr("hidden_act", None) + self._force_setattr("intermediate_size", None) + else: + assert self.intermediate_size is not None, ( + "Intermediate size must be provided for an FFN block" + ) + assert self.intermediate_size % 256 == 0, "Intermediate size must be divisible by 256" + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self) + + @property + def is_moe(self) -> bool: + return self.moe is not None + + +SUBBLOCK_CLS_DICT = { + "attention": AttentionConfig, + "ffn": FFNConfig, +} + + +@dataclass(frozen=True, kw_only=True) +class BlockConfig(BaseDataclass): + attention: Optional[AttentionConfig] = None + ffn: Optional[FFNConfig] = None + parallel_blocks: Optional[list["BlockConfig"]] = None + + def __post_init__(self): + super().__post_init__() + if (self.parallel_blocks is not None) and isinstance(self.parallel_blocks[0], dict): + initialized_block_configs = [ + BlockConfig(**block_config) for block_config in self.parallel_blocks + ] + self._force_setattr("parallel_blocks", initialized_block_configs) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py new file mode 100644 index 0000000000..c37b9adaf7 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import copy +import dataclasses +import warnings +from typing import Any + +from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available + +from .block_config import BlockConfig +from .transformers_4_44_2__configuration_llama import LlamaConfig + +# fakes imports to make AutoConfig infer dependencies +from .transformers_4_44_2__modeling_rope_utils import rope_config_validation +from .transformers_4_51_3__cache_utils import HybridChunkedCache +from .transformers_4_51_3__configuration_llama4 import Llama4Config + +# make sure that auto-formatting doesn't remove the fake imports +rope_config_validation +Llama4Config +HybridChunkedCache + + +class DeciLMConfig(LlamaConfig): + model_type = "nemotron-nas" + + # Mapping from global attribute names to their per-layer equivalents in block_configs + # Format: 'global_name': ('block_section', 'layer_name') + PER_LAYER_ATTRIBUTE_MAPPING = { + "intermediate_size": ("ffn", "intermediate_size"), + "num_key_value_heads": ( + "attention", + "n_heads_in_group", + ), # Note: derived value (num_heads / num_kv_heads) + "hidden_act": ("ffn", "hidden_act"), + "sliding_window": ("attention", "window_length"), # Note: different name! + } + + def __init__( + self, + block_configs: list[dict] | list[BlockConfig] | None = None, + position_embedding_type: str = "rope", + llama4_attn_implementation: str | None = None, + block_return_only_hidden_states: bool = False, + router_aux_loss_coef: float = 0.01, + router_z_loss_coef: float = 0.0, + output_router_logits: bool = False, + head_dim: int | None = 128, + o_proj_bias: bool = False, + **kwargs, + ): + self.block_configs: list[BlockConfig] = block_configs + if self.block_configs is not None: + if isinstance(self.block_configs[0], dict): + self.block_configs = [BlockConfig(**conf) for conf in self.block_configs] + + assert position_embedding_type in ["rope", "rope_llama4", "none", "mistral_yarn"] + self.position_embedding_type = position_embedding_type + if self.position_embedding_type == "none": + self.rope_theta = None + self.rope_scaling = None + + self.block_return_only_hidden_states = block_return_only_hidden_states + self.router_aux_loss_coef = router_aux_loss_coef + self.router_z_loss_coef = router_z_loss_coef + self.output_router_logits = output_router_logits + self.o_proj_bias = o_proj_bias + + self._choose_llama4_attn_implementation(llama4_attn_implementation) + attn_implementation = self._choose_llama3_attn_implementation(kwargs) + super().__init__(attn_implementation=attn_implementation, **kwargs) + self.head_dim = ( + head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + ) + + # Delete per-layer attributes after parent init (they should only exist in block_configs) + self._delete_per_layer_attributes() + + if self.block_configs is not None: + assert len(self.block_configs) == self.num_hidden_layers + + def _delete_per_layer_attributes(self): + """Delete per-layer attributes that should only exist in block_configs. + + These attributes are intentionally deleted AFTER super().__init__() to ensure + they don't exist at the global config level. Deleting them (rather than setting + to None) makes it clear they shouldn't be accessed globally. + """ + present_attrs = { + attr: getattr(self, attr) + for attr in self.PER_LAYER_ATTRIBUTE_MAPPING + if hasattr(self, attr) + } + if present_attrs: + warnings.warn( + f"Deleting global per-layer attributes (should only be in block_configs): {present_attrs}", + UserWarning, + stacklevel=3, + ) + for attr in self.PER_LAYER_ATTRIBUTE_MAPPING: + if hasattr(self, attr): + delattr(self, attr) + + def _choose_llama4_attn_implementation(self, llama4_attn_implementation): + self.llama4_attn_implementation = llama4_attn_implementation + if self.llama4_attn_implementation is None: + if is_torch_sdpa_available(): + _print_once("auto-setting llama4_attn_implementation to sdpa") + self.llama4_attn_implementation = "sdpa" + else: + _print_once("auto-setting llama4_attn_implementation to eager") + self.llama4_attn_implementation = "eager" + + def _choose_llama3_attn_implementation(self, kwargs: dict[str, Any]) -> str: + attn_implementation = kwargs.pop("attn_implementation", None) + if attn_implementation is None and is_flash_attn_2_available(): + _print_once("auto-setting attn_implementation (for Llama3 layers) to flash_attention_2") + attn_implementation = "flash_attention_2" + + if self.block_configs is not None: + using_unshifted_sink = any( + block_config.attention.unshifted_sink for block_config in self.block_configs + ) + if using_unshifted_sink and attn_implementation != "eager": + warnings.warn( + "Forcing attn_implementation='eager' since some attention layers use unshifted sink" + ) + attn_implementation = "eager" + return attn_implementation + + def to_dict(self) -> dict[str, Any]: + """Convert config to dictionary, removing per-layer-only attributes.""" + self_dict = super().to_dict() + if self.block_configs is not None: + self_dict["block_configs"] = [dataclasses.asdict(conf) for conf in self.block_configs] + + # Remove global keys that should only exist per-layer in block_configs + for key in self.PER_LAYER_ATTRIBUTE_MAPPING: + self_dict.pop(key, None) + + return self_dict + + def set_block_configs(self, block_configs: list[BlockConfig]) -> "DeciLMConfig": + new_model_config = copy.deepcopy(self) + new_model_config.block_configs = block_configs + new_model_config.num_hidden_layers = len(block_configs) + return new_model_config + + def get_num_hidden_layers(self) -> int: + return self.num_hidden_layers + + def get_hidden_size(self) -> int: + return self.hidden_size + + def get_embedding_layer_name(self) -> str: + return "model.embed_tokens" + + def get_final_layer_norm_layer_name(self) -> str: + return "model.norm" + + def get_lm_head_layer_name(self) -> str: + return "lm_head" + + def get_layers_layer_name(self) -> str: + return "model.layers" + + def get_block_config(self, layer_idx: int | tuple[int, ...]) -> BlockConfig: + if isinstance(layer_idx, tuple) and len(layer_idx) == 1: + layer_idx = layer_idx[0] + + if isinstance(layer_idx, int): + return self.block_configs[layer_idx] + + external_layer_idx, internal_layer_idx = layer_idx + return self.block_configs[external_layer_idx].parallel_blocks[internal_layer_idx] + + def get_min_attention_chunk_size(self) -> int | None: + min_chunk_size = float("inf") + for block_config in self.block_configs: + if block_config.attention.llama4 is not None: + attention_chunk_size = block_config.attention.llama4.attention_chunk_size + if attention_chunk_size is not None: + min_chunk_size = min(min_chunk_size, attention_chunk_size) + + if min_chunk_size == float("inf"): + return None + return min_chunk_size + + +def _print_once(message: str): + if not hasattr(_print_once, "was_printed"): + _print_once.was_printed = set() + if message not in _print_once.was_printed: + _print_once.was_printed.add(message) + print(message) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py new file mode 100644 index 0000000000..76dbb3473b --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Adapted from megatron.core.ssm.mamba_mixer.MambaMixer: +# https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/0b5140009fb9011eceaef6d36ea1181a8d176479/megatron/core/ssm/mamba_mixer.py + +# ruff: noqa: N803, N806 + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + from einops import rearrange, repeat + from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import ( + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + ) + + class MambaMixerMegatron(nn.Module): + """ + Args: + d_model: The hidden size of the model. + d_state: The state size of the SSM. + d_conv: The number of channels in the causal convolution. + conv_init: The initialization range for the causal convolution weights. + nheads: The number of Mamba heads. Used to calculate the expansion factor for the SSM + instead of the deprecated arg "expand". + headdim: The hidden size of each attention head. + ngroups: The number of attention heads. + A_init_range: The initialization range for the attention weights. + D_has_hdim: Whether the D parameter has the same number of dimensions as the hidden + state. + rmsnorm: Whether to use root mean square normalization. + norm_before_gate: Whether to apply normalization before the gating mechanism. + dt_min: The minimum value of the dt parameter. + dt_max: The maximum value of the dt parameter. + dt_init: The initialization value of the dt parameter. + dt_scale: The scaling factor for the dt parameter. + dt_init_floor: The minimum value of the dt parameter after initialization. + bias: Whether to use bias in the linear layers. + conv_bias: Whether to use bias in the causal convolution. + chunk_size: The chunk size for the fused kernel. + use_mem_eff_path: Whether to use the memory-efficient path for the Mamba model. + layer_number: The layer number of this Mamba layer. + """ + + def __init__( + self, + d_model, + d_state=256, + d_conv=4, + conv_init=None, + nheads=256, + headdim=64, + ngroups=8, + A_init_range=(1, 16), + D_has_hdim=False, + rmsnorm=True, + norm_before_gate=False, + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + bias=False, + conv_bias=True, + # Fused kernel and sharding options + chunk_size=128, + use_mem_eff_path=True, + layer_number=None, + ): + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.conv_init = conv_init + self.nheads = nheads + self.headdim = headdim + self.ngroups = ngroups + self.D_has_hdim = D_has_hdim + self.rmsnorm = rmsnorm + self.norm_before_gate = norm_before_gate + self.chunk_size = chunk_size + self.use_mem_eff_path = use_mem_eff_path + self.layer_number = layer_number + + self.d_inner = self.nheads * self.headdim + + self.tensor_model_parallel_size = 1 + assert self.d_inner % self.tensor_model_parallel_size == 0 + assert self.ngroups % self.tensor_model_parallel_size == 0 + assert self.nheads % self.tensor_model_parallel_size == 0 + assert not bias + assert not self.norm_before_gate + + self.d_inner_local = self.d_inner // self.tensor_model_parallel_size + self.ngroups_local = self.ngroups // self.tensor_model_parallel_size + self.nheads_local = self.nheads // self.tensor_model_parallel_size + + assert self.d_inner_local % self.ngroups_local == 0 + + # Assume sequence parallelism: input is already partitioned along the + # sequence dimension + self.in_proj = nn.Linear( + self.d_model, + self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads, # AB CD E + bias=False, + ) + + conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state # A CD + + # weight dim: [conv_dim, conv_dim, d_conv] + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + ) + + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + + self.activation = "silu" + self.act = nn.SiLU() + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.nheads_local) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Our initialization would set all Linear.bias to zero, + # need to mark this one as _no_reinit + self.dt_bias._no_reinit = True + # Just to be explicit. Without this we already don't + # put wd on dt_bias because of the check + + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads_local, dtype=torch.float32).uniform_(*A_init_range) + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter( + torch.ones( + self.d_inner_local if self.D_has_hdim else self.nheads_local, + ) + ) # Keep in fp32 + self.D._no_weight_decay = True + + if self.rmsnorm: + self.norm = RMSNormGated( + self.d_inner_local, + eps=1e-5, + group_size=self.d_inner_local // self.ngroups_local, + norm_before_gate=self.norm_before_gate, + ) + + # Assume sequence parallelism: input is partitioned along d_inner and + # output is partitioned along the sequence dimension + self.out_proj = nn.Linear( + self.d_inner, + self.d_model, + bias=False, + ) + + def forward(self, hidden_states, inference_params=None): + """ + hidden_states: (nL, B, D) / (L B D) + Returns: same shape as hidden_states + """ + _, batch, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + # assert not self.config.sequence_parallel + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, out_bias, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out, out_bias + + # (nheads_local) + A = -torch.exp(self.A_log.float()) + + # xz, _ = self.in_proj(hidden_states) # TransformerEngine also returns bias + xz = self.in_proj(hidden_states) + + # transpose: l b pd --> b l pd + xz = rearrange(xz, "l b d -> b l d").contiguous() + + if self.use_mem_eff_path and inference_params is None: + assert ssm_state is None + + if self.conv1d.bias is not None: + self.conv1d.bias.data_ptr() + + y = mamba_split_conv1d_scan_combined( + xz, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.dt_bias.float(), + A, + D=( + rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.D + ), + chunk_size=self.chunk_size, + activation=self.activation, + headdim=None if self.D_has_hdim else self.headdim, + ngroups=self.ngroups_local, + norm_before_gate=self.norm_before_gate, + ) + + if self.rmsnorm: + y = self.norm(y) + else: + z, xBC, dt = torch.split( + xz, + [ + self.d_inner_local, + self.d_inner_local + 2 * self.ngroups_local * self.d_state, + self.nheads_local, + ], + dim=-1, + ) + + # transpose: b l pd --> b pd l + xBC = rearrange(xBC, "b l d -> b d l").contiguous() + + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + conv_state.copy_( + F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) + ) # Update state (B D W) + + seqlen = xBC.size(2) + if causal_conv1d_fn is None: + xBC = self.act(self.conv1d(xBC)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + xBC = causal_conv1d_fn( + x=xBC, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # transpose b pd l --> b l pd + xBC = rearrange(xBC, "b d l -> b l d").contiguous() + + x, B, C = torch.split( + xBC, + [ + self.d_inner_local, + self.ngroups_local * self.d_state, + self.ngroups_local * self.d_state, + ], + dim=-1, + ) + + # TO DO Vijay: fuse most of the transposes with the GEMMS + x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim).contiguous() + dt = dt.contiguous() + B = rearrange(B, "b l (g n) -> b l g n", n=self.d_state).contiguous() + C = rearrange(C, "b l (g n) -> b l g n", n=self.d_state).contiguous() + z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim).contiguous() + y = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + self.chunk_size, + D=( + rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.D + ), + z=z if not self.rmsnorm else None, + dt_bias=self.dt_bias.float(), + dt_softplus=True, + return_final_states=ssm_state is not None, + ) + + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + + if self.rmsnorm: + y = rearrange(y, "b l h p -> b l (h p)").contiguous() + z = rearrange(z, "b l h p -> b l (h p)").contiguous() + y = self.norm(y, z) + else: + y = rearrange(y, "b l h p -> b l (h p)").contiguous() + + y = rearrange(y, "b l d -> l b d").contiguous() + # out, out_bias = self.out_proj(y) # TransformerEngine also returns bias + out = self.out_proj(y) + + return out + + def step(self, hidden_states, conv_state, ssm_state): + """ + Performs inference step for decoding + """ + # assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now" + dtype = hidden_states.dtype + assert hidden_states.shape[0] == 1, ( + "Only support decoding with 1 token at a time for now" + ) + + # l b d --> b d + hidden_states = hidden_states.squeeze(0) + + # b d_model --> b p(2d) + xz, _ = self.in_proj(hidden_states) + + z, xBC, dt = torch.split( + xz, + [ + self.d_inner_local, + self.d_inner_local + 2 * self.ngroups_local * self.d_state, + self.nheads_local, + ], + dim=-1, + ) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum( + conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 + ) # (B D) + if self.conv1d.bias is not None: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(dtype=dtype) + else: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x, B, C = torch.split( + xBC, + [ + self.d_inner_local, + self.ngroups_local * self.d_state, + self.ngroups_local * self.d_state, + ], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) + + # SSM step + if selective_state_update is None: + if self.ngroups_local > 1: + B = rearrange(B, "b (g n) -> b g n", n=self.d_state) + C = rearrange(C, "b (g n) -> b g n", n=self.d_state) + B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) + C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) + + dt = repeat(dt, "b h -> b (h p)", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> (h p)", p=self.headdim) + A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state) + D = repeat(self.D, "h -> (h p)", p=self.headdim) + + dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + + dB_x = torch.einsum("bd,bdn,bd->bdn", dt, B, x) + ssm_state.copy_( + ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim) + + rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim) + ) + + y = torch.einsum( + "bdn,bdn->bd", + rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim), + C, + ) + y = y + D.to(dtype) * x + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + # Discretize A and B (b (g n)) + dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) + dA = torch.exp(dt * A) + x = rearrange(x, "b (h p) -> b h p", p=self.headdim) + dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) + ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) + y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) + y = y + rearrange(self.D.to(dtype), "h -> h 1") * x + y = rearrange(y, "b h p -> b (h p)") + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) + dt = repeat(dt, "b h -> b h p", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) + D = repeat(self.D, "h -> h p", p=self.headdim) + B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local) + C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local) + x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + y = selective_state_update( + ssm_state, + x_reshaped, + dt, + A, + B, + C, + D, + z=z if not self.rmsnorm else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + y = rearrange(y, "b h p -> b (h p)") + + if self.rmsnorm: + y = self.norm(y, z) + + # b pd --> b d + out, out_bias = self.out_proj(y) + return out.unsqueeze(0), out_bias, conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + """ + allocate inference cache + """ + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], + self.d_conv, + device=device, + dtype=conv_dtype, + ) + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, + self.nheads_local, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_number is not None + if self.layer_number not in inference_params.key_value_memory_dict: + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.nheads_local, + self.headdim, + self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_number] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_number] + # TO DO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state + +except ImportError as exception: + mamba_error_message = f"Cannot declare MambaMixer due to missing dependencies: {exception=}." + warnings.warn(mamba_error_message) + + # TODO: Investigate why this type ignore is needed + class MambaMixerMegatron(nn.Module): # type: ignore[no-redef] + def __init__(self, *args, **kwargs): + raise ImportError(mamba_error_message) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py new file mode 100644 index 0000000000..1b3840a300 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any + +import numpy + + +class MegatronTokenizer(ABC): + """Abstract class for tokenizer + + Absent a config or class-specific tracking of which objects are uniquely identifying, we must + include all key word arguments as unique identifiers + + Args: + tokenizer_paths (Tuple[str]): All tokenizer source paths or prefixes + + tokenizer_options (Dict[str, Any]): All tokenizer options + """ + + def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any): + self.unique_identifiers = OrderedDict() + self.unique_identifiers["class"] = type(self).__name__ + self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths) + for option in tokenizer_options: + self.unique_identifiers[option] = str(tokenizer_options[option]) + + self.unique_description = json.dumps(self.unique_identifiers, indent=4) + + super().__init__() + + @abstractmethod + def tokenize(self, text: str) -> numpy.ndarray: + """Convert text to embedding ids + + Args: + text (str): The text to convert + + Returns: + numpy.ndarray: The converted embedding ids + """ + + def detokenize(self, ids: numpy.ndarray) -> str: + """Convert embedding ids to text + + Args: + ids (numpy.ndarray): The ids to convert + + Returns: + str: The converted text + + Raises: + NotImplementedError: Non-abstract, optional method + """ + raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__)) + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token""" + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token""" + + @property + @abstractmethod + def vocab_size(self): + """The vocabulary size""" + + @property + def cls(self): + """The CLS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'cls'".format(type(self).__name__)) + + @property + def sep(self): + """The SEP token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'sep'".format(type(self).__name__)) + + @property + def pad(self): + """The PAD token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__)) + + @property + def eod(self): + """The EOD token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eod'".format(type(self).__name__)) + + @property + def bos(self): + """The BOS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__)) + + @property + def eos(self): + """The EOS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__)) + + @property + def mask(self): + """The MASK token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'mask'".format(type(self).__name__)) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py new file mode 100644 index 0000000000..5c641d25b9 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Megatron tokenizers.""" + +import base64 +import json +from pathlib import Path + +from .megatron_lm__megatron_tokenizer import MegatronTokenizer + + +def reload_mergeable_ranks( + path: str, + max_vocab: int | None = None, +) -> dict[bytes, int]: + """ + Reload our tokenizer JSON file and convert it to Tiktoken format. + """ + assert path.endswith(".json") + + # reload vocab + with open(path) as f: + vocab = json.load(f) + assert isinstance(vocab, list) + print(f"Vocab size: {len(vocab)}") + if max_vocab is not None: + vocab = vocab[:max_vocab] + print(f"Cutting vocab to first {len(vocab)} tokens.") + + # build ranks + ranks: dict[bytes, int] = {} + for i, x in enumerate(vocab): + assert x.keys() == {"rank", "token_bytes", "token_str"} + assert x["rank"] == i + merge = base64.b64decode(x["token_bytes"]) + assert i >= 256 or merge == bytes([i]) + ranks[merge] = x["rank"] + + # sanity check + assert len(ranks) == len(vocab) + assert set(ranks.values()) == set(range(len(ranks))) + + return ranks + + +PATTERN_TIKTOKEN = ( + r"[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" +) +PATTERN_TIKTOKEN_V2 = ( + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+" + "|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*" + "|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +) + + +class CustomTikTokenizer(MegatronTokenizer): + def __init__( + self, + path: str, + pattern: str, + vocab_size: int, + num_special_tokens: int, + special_tokens: list[str] | None, + ): + super().__init__( + path, + pattern=pattern, + vocab_size=vocab_size, + num_special_tokens=num_special_tokens, + special_tokens=special_tokens, + ) + import tiktoken + + # if vocab_size is None: + # vocab_size = 2**17 # Fallback vocab size is 131072. + self._vocab_size = vocab_size + + special_tokens_default = ["", "", ""] + if special_tokens is None: + special_tokens = special_tokens_default.copy() + assert len(special_tokens) == len(set(special_tokens)), ( + f"Special tokens should be unique: {special_tokens}" + ) + assert len(special_tokens) <= num_special_tokens < self._vocab_size + assert set(special_tokens_default) <= set(special_tokens), ( + f"Custom special tokens should include {special_tokens_default}" + ) + + special_filler = [ + "".format(id=i) for i in range(len(special_tokens), num_special_tokens) + ] + if special_filler: + print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}") + special_tokens = special_tokens + special_filler + assert len(set(special_tokens)) == len(special_tokens) == num_special_tokens, special_tokens + inner_vocab_size = self._vocab_size - num_special_tokens + + token_to_id_without_special_tokens = reload_mergeable_ranks( + path, max_vocab=inner_vocab_size + ) + # Create space for special tokens. + token_to_id_without_special_tokens = { + t: i + num_special_tokens for t, i in token_to_id_without_special_tokens.items() + } + + special_tokens = {t: i for i, t in enumerate(special_tokens)} + self._unk_id = special_tokens[""] + self._bos_id = special_tokens[""] + self._eos_id = special_tokens[""] + + # Create tiktoken model. + self._model = tiktoken.Encoding( + name=Path(path).parent.name, + pat_str=pattern, + mergeable_ranks=token_to_id_without_special_tokens, + special_tokens=special_tokens, + ) + + # Create final _id_to_token and _token_to_id data structures with special tokens inserted + # into appropriate locations. + assert set(token_to_id_without_special_tokens.keys()).isdisjoint(set(special_tokens.keys())) + self._token_to_id = token_to_id_without_special_tokens.copy() + self._token_to_id.update(special_tokens) + self._id_to_token = {v: k for k, v in self._token_to_id.items()} + assert set(range(self._vocab_size)) == set(self._id_to_token.keys()) + + @property + def bos(self) -> int: + return self._bos_id + + @property + def eos(self) -> int: + return self._eos_id + + @property + def unk(self) -> int: + return self._unk_id + + @property + def eod(self) -> int: + return self._eos_id + + @property + def vocab(self): + return self._token_to_id + + @property + def inv_vocab(self): + return self._id_to_token + + def tokenize(self, s: str, bos: bool = False, eos: bool = False) -> list[int]: + tokens = self._model.encode_ordinary(s) + if bos: + tokens = [self.bos, *tokens] + if eos: + tokens = [*tokens, self.eos] + + return tokens + + def detokenize(self, tokens: list[int]) -> str: + return self._model.decode(tokens) + + @property + def vocab_size(self) -> int: + return self._vocab_size + + @property + def encoder(self): + return self._token_to_id + + @property + def decoder(self): + return self._id_to_token diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py new file mode 100644 index 0000000000..808533d7f8 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py @@ -0,0 +1,2627 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved. +# +# This code for Nvidia's model is based on the Llama modeling code by HuggingFace, +# which is in turn based on EleutherAI's GPT-NeoX library and the GPT-NeoX and +# OPT implementations in this library. +# Sliding window code based on Gemma2 by Google. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import math + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import GenerationConfig +from transformers.generation.utils import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .block_config import AttentionConfig, FFNConfig, MambaConfig, MoEConfig +from .configuration_decilm import DeciLMConfig +from .megatron_lm__mamba_mixer import MambaMixerMegatron +from .transformers_4_44_2__activations import ACT2FN +from .transformers_4_44_2__cache_utils import Cache, StaticCache +from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter +from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import ( + _flash_attention_forward, +) +from .transformers_4_44_2__modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS +from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS +from .transformers_4_51_3__modeling_llama4_attention import Llama4TextAttention, Llama4TextConfig +from .variable_cache import VariableCache +from .vllm_yarn_utils import YaRNScalingRotaryEmbedding + +# from transformers.models.llama4.modeling_llama4 import Llama4TextL2Norm +MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[DeciLMConfig.model_type] = "DeciLMForCausalLM" +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeciLMConfig" + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or + a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be + as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Llama4TextL2Norm(torch.nn.Module): + def __init__(self, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) + + def extra_repr(self): + return f"eps={self.eps}" + + +class DeciLMRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeciLMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm) + + +class DeciLMRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: DeciLMConfig | None = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_impl = "rope" if config is None else config.position_embedding_type + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + def _set_inv_freq_if_needed(self, device: torch.device) -> None: + is_missing_inv_freq = not hasattr(self, "inv_freq") + is_meta_mismatch = not is_missing_inv_freq and ( + str(device) != "meta" and self.inv_freq.is_meta + ) + + if is_missing_inv_freq or is_meta_mismatch: + with torch.device(device): + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, **self.rope_kwargs + ) + self.original_inv_freq = inv_freq + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len + and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + self._set_inv_freq_if_needed(x.device) + + if self.rope_impl == "rope_llama4": + return self.llama4_forward(x, position_ids) + else: + return self.llama3_forward(x, position_ids) + + def llama3_forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def llama4_forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # Convert to complex representation + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + freqs_cis = freqs_cis * self.attention_scaling + return freqs_cis + + +class DeciMistralYarnRotaryEmbedding(nn.Module): + def __init__(self, config: DeciLMConfig): + super().__init__() + self.config = config + self.rope_scaling = config.rope_scaling + self.base = config.rope_theta + self.rope_impl = config.position_embedding_type + self.head_size = config.hidden_size // config.num_attention_heads + self.yarn = YaRNScalingRotaryEmbedding( + head_size=self.head_size, + rotary_dim=self.head_size, + max_position_embeddings=self.rope_scaling["original_max_position_embeddings"], + base=self.base, + is_neox_style=True, + scaling_factor=self.rope_scaling["factor"], + beta_fast=self.rope_scaling["beta_fast"], + beta_slow=self.rope_scaling["beta_slow"], + dtype=torch.float32, + ) + self.attention_scaling = self.yarn.mscale + self.scaling_factor = self.rope_scaling["factor"] + self.rope_impl = "rope" if config is None else config.position_embedding_type + self.rope_impl = "even_odd" + + def _set_inv_freq_if_needed(self, device: torch.device) -> None: + is_missing_inv_freq = not hasattr(self, "inv_freq") + is_meta_mismatch = not is_missing_inv_freq and ( + str(device) != "meta" and self.inv_freq.is_meta + ) + + if is_missing_inv_freq or is_meta_mismatch: + with torch.device(device): + inv_freq = self.yarn._compute_inv_freq(self.scaling_factor) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def halves_forward(self, x, position_ids): + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + + self._set_inv_freq_if_needed(x.device) + + # print(f"halves_forward") + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + inv_freq_expanded = inv_freq_expanded.to(x.device) + # print(f"inv_freq_expanded: {inv_freq_expanded.device}") + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def forward(self, x, position_ids): + if self.rope_impl == "halves": + return self.halves_forward(x, position_ids) + elif self.rope_impl == "even_odd": + return self.even_odd_forward(x, position_ids) + else: + raise ValueError(f"Invalid rope implementation: {self.rope_impl}") + + def even_odd_forward(self, x, position_ids): + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + + self._set_inv_freq_if_needed(x.device) + + # print(f"even_odd_forward") + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # Convert to complex representation + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + freqs_cis = freqs_cis * self.attention_scaling + return freqs_cis + + +class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding): + """DeciLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`DeciLMLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding): + """DeciLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +rope_type_to_class = { + "default": DeciLMRotaryEmbedding, + "linear": DeciLMLinearScalingRotaryEmbedding, + "dynamic": DeciLMDynamicNTKScalingRotaryEmbedding, + "rope_llama4": DeciLMRotaryEmbedding, + "rope": DeciLMRotaryEmbedding, + "mistral_yarn": DeciMistralYarnRotaryEmbedding, +} + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, freqs_cis, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + freqs_cis (`torch.Tensor`): The frequency tensor. + a tuple of two tensors, cos and sin. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # print(f"applying first half-second half") + cos, sin = freqs_cis + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def vllm_apply_rotary_emb_torch( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + # print(f"freqs_cis: {freqs_cis.shape}, xq_: {xq_.shape}, xk_: {xk_.shape}") + xq_out = torch.view_as_real(xq_ * freqs_cis[:, None, :, :]).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis[:, None, :, :]).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class DeciLMGatedMLP(nn.Module): + def __init__( + self, + config: DeciLMConfig, + ffn_config: FFNConfig, + ): + super().__init__() + self.config = config + self.ffn_config = ffn_config + self.hidden_size = config.hidden_size + self.intermediate_size = ffn_config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[ffn_config.hidden_act] + + if ffn_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + up_proj = torch.cat( + [F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class DeciLMVanillaMLP(nn.Module): + def __init__( + self, + config: DeciLMConfig, + ffn_config: FFNConfig, + ): + super().__init__() + self.config = config + self.ffn_config = ffn_config + self.hidden_size = config.hidden_size + self.intermediate_size = ffn_config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[ffn_config.hidden_act] + + if ffn_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + assert self.config.pretraining_tp == 1, ( + "Unsupported pretraining_tp != 1 for DeciLMVanillaMLP" + ) + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class DeciLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: DeciLMConfig, + attention_config: AttentionConfig, + layer_idx: int | None = None, + ): + super().__init__() + self.config = config + self.attention_config = attention_config # type: AttentionConfig + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = attention_config.n_heads_in_group # DeciLM-specific code + self.num_key_value_heads = ( + self.num_heads // self.num_key_value_groups + ) # DeciLM-specific code + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + # llama4 attention specific + self.llama4_attn_config = attention_config.llama4 + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.o_proj_bias + ) + + if self.config.position_embedding_type in ["rope", "rope_llama4", "mistral_yarn"]: + # TO DO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = rope_type_to_class[self.config.position_embedding_type]( + config=self.config + ) + + if attention_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + self.is_llama4 = self.llama4_attn_config is not None + if ( + self.is_llama4 + and self.llama4_attn_config.use_qk_norm + and self.llama4_attn_config.use_rope + ): + self.qk_norm = Llama4TextL2Norm(self.config.rms_norm_eps) + + self.use_rope = ( + self.llama4_attn_config.use_rope + if self.is_llama4 + else self.config.position_embedding_type in ["rope", "mistral_yarn"] + ) + self.rope_impl = self.rotary_emb.rope_impl + self.apply_rope_fn = ( + apply_rotary_emb + if self.rope_impl in ["even_odd", "rope_llama4"] + else apply_rotary_pos_emb + ) + # self.apply_rope_fn = apply_rotary_emb + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # will become mandatory in v4.45 + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + + if self.config.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if self.use_rope: + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE " + "embeddings internally through `position_ids` (2D tensor with the indexes of the " + "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " + "containing cos and sin). In v4.45 `position_ids` will be removed and " + "`position_embeddings` will be mandatory." + ) + freqs_cis = self.rotary_emb(value_states, position_ids) + else: + freqs_cis = position_embeddings + + query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + if self.is_llama4: + query_states = self.apply_attention_scaling(input_shape, cache_position, query_states) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + # print(f"cache_position: {cache_position}") + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.config.pretraining_tp, dim=1 + ) + attn_output = sum( + [ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def apply_attention_scaling(self, input_shape, cache_position, query_states): + # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers + if self.llama4_attn_config.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + torch.log( + torch.floor( + (cache_position.float() + 1.0) / self.llama4_attn_config.floor_scale + ) + + 1.0 + ) + * self.llama4_attn_config.attn_scale + + 1.0 + ) + attn_scales = attn_scales.view((*input_shape, 1, 1)).transpose(1, 2) + query_states = (query_states * attn_scales).to(query_states.dtype) + return query_states + return query_states + + +class DeciLMFlashAttention2(DeciLMAttention): + """ + DeciLM flash attention module. This module inherits from `DeciLMAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is + # bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is + # used to handle this difference. + # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case + # q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + self.sliding_window = self.attention_config.prefill_sliding_window + + self.pre_attention_identity_query = nn.Identity() # for debugging hooks + self.pre_attention_identity_key = nn.Identity() # for debugging hooks + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # will become mandatory in v4.45 + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if self.config.position_embedding_type in ["rope", "mistral_yarn"]: + # llama4 doesn't use flash attention + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE " + "embeddings internally through `position_ids` (2D tensor with the indexes of the " + "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " + "containing cos and sin). In v4.45 `position_ids` will be removed and " + "`position_embeddings` will be mandatory." + ) + freqs_cis = self.rotary_emb(value_states, position_ids) + else: + freqs_cis = position_embeddings + + query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freq_cis) + # print(f"applying even odd rope") + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV + # cache to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeciLMRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + query_states = self.pre_attention_identity_query(query_states) + key_states = self.pre_attention_identity_key(key_states) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=self.sliding_window, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +DECILM_ATTENTION_CLASSES = { + "eager": DeciLMAttention, + "flash_attention_2": DeciLMFlashAttention2, +} + + +class DeciLMLlama4TextAttention(Llama4TextAttention): + def __init__(self, config: DeciLMConfig, layer_idx: int, attention_config: AttentionConfig): + llama4_text_config = Llama4TextConfig( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_attention_heads // attention_config.n_heads_in_group, + head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), + attn_scale=attention_config.llama4.attn_scale, + floor_scale=attention_config.llama4.floor_scale, + attn_temperature_tuning=attention_config.llama4.attn_temperature_tuning, + attention_dropout=attention_config.llama4.attention_dropout, + use_qk_norm=attention_config.llama4.use_qk_norm, + use_rope=attention_config.llama4.use_rope, + rms_norm_eps=config.rms_norm_eps, + attention_bias=config.attention_bias, + attn_implementation=config.llama4_attn_implementation, + rope_scaling=config.rope_scaling, + max_position_embeddings=config.max_position_embeddings, + attention_chunk_size=attention_config.llama4.attention_chunk_size, + ) + super().__init__(llama4_text_config, layer_idx, use_rope=attention_config.llama4.use_rope) + + +class DeciLMDecoderLayer(nn.Module): + # DeciLM-specific code + def __init__(self, config: DeciLMConfig, layer_idx: int | tuple[int, ...]): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.block_config = config.get_block_config(layer_idx) + + self.attention_config = self.block_config.attention + self.ffn_config = self.block_config.ffn + self.layer_idx = layer_idx + + if not self.attention_config.no_op: + self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.attention_config.replace_with_linear: + self.self_attn = DeciLMLinearAttention(config) + elif self.attention_config.is_mamba: + self.self_attn = DeciLMMambaMixer(config, self.attention_config.mamba) + elif not self.attention_config.is_llama4: + self.self_attn = DECILM_ATTENTION_CLASSES[config._attn_implementation]( + config=config, attention_config=self.attention_config, layer_idx=layer_idx + ) + else: + self.self_attn = DeciLMLlama4TextAttention(config, layer_idx, self.attention_config) + + if not (self.ffn_config.no_op or self.attention_config.is_mamba): + if self.ffn_config.hidden_act is None: + print(f"WARNING: FFN hidden_act is None for layer {layer_idx}") + + self.post_attention_layernorm = DeciLMRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + if self.ffn_config.replace_with_linear: + self.mlp = DeciLMLinearMLP(config) + elif self.ffn_config.is_moe: + self.mlp = DeciLMMoe(config, self.ffn_config) + else: + self.mlp = ( + DeciLMGatedMLP(config, self.ffn_config) + if self.ffn_config.gated + else DeciLMVanillaMLP(config, self.ffn_config) + ) + + self.is_sliding = self.attention_config.is_sliding + self.sliding_window = self.attention_config.prefill_sliding_window + self.return_only_hidden_states = self.config.block_return_only_hidden_states + + @property + def device(self): + try: + return next(self.parameters()).device + except StopIteration: + return None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool | None = False, + output_router_logits: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + **kwargs, + ) -> ( + tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None] + | torch.FloatTensor + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + paramz = list(self.parameters()) + device = paramz[0].device if len(paramz) > 0 else None + if isinstance(hidden_states, tuple): + # could happen when sewing kit sends the output of the previous layer + # to this layer without going through the model forward unpacking code. + # can be avoided by using config.block_return_only_hidden_states=True + hidden_states = hidden_states[0] + + hidden_states = hidden_states.to(device) + + if cache_position is not None: + cache_position = cache_position.to(device) + + if self.attention_config.llama4 is not None: + # chunk_size = self.attention_config.llama4.attention_chunk_size + # print(f"pre-llama4_update: {attention_mask=}") + # causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( + # attention_mask, hidden_states, cache_position, past_key_value, output_attentions, use_cache=use_cache, + # ) + # attention_mask = causal_mask if (chunk_size is None) else chunk_causal_mask + # if (past_key_value is not None) and isinstance(attention_mask, BlockMask): + # print(f"pre-adjust: {attention_mask.shape=}") + # print(f"pre-adjust: {hidden_states.shape=}") + # print(f"pre-adjust: {past_key_value.get_seq_length()=}") + # q_len = hidden_states.shape[1] + # kv_len = past_key_value.get_seq_length() + # if kv_len == 0: + # kv_len = q_len + # print(f"pre-adjust: {kv_len=} {q_len=}") + # print(f"post-adjust: {attention_mask.shape=}") + assert self.config.llama4_attn_implementation != "flex_attention", ( + "We have a mask issue with flex attention" + ) + + causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( + attention_mask, + hidden_states, + cache_position, + past_key_value, + output_attentions, + use_cache=use_cache, + ) + is_chunked = self.attention_config.llama4.attention_chunk_size is not None + attention_mask = ( + chunk_causal_mask if is_chunked and (chunk_causal_mask is not None) else causal_mask + ) + + else: + attention_mask = self._llama3_update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_value, output_attentions + ) + if self.attention_config.unshifted_sink and self.attention_config.is_sink: + attention_mask = self._unshifted_sink_mask( + attention_mask, + hidden_states, + self.attention_config.window_length, + self.attention_config.num_sink_tokens, + ) + else: + attention_mask = self._gemma2_window_mask( + attention_mask, hidden_states, past_key_value + ) + + self_attn_weights = None + present_key_value = past_key_value + router_logits = None + + if self.attention_config.no_op: + pass + elif self.attention_config.replace_with_linear or self.attention_config.is_mamba: + if self.attention_config.is_mamba: + assert past_key_value is None, "DeciLM does not support generation with Mamba yet" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + else: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_out = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states, self_attn_weights = attn_out[:2] + if len(attn_out) > 2: + present_key_value = attn_out[2] + + hidden_states = residual + hidden_states + + if not self.ffn_config.no_op: + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # Handle MoE layers differently as they return router logits + if self.ffn_config.is_moe: + hidden_states, router_logits = self.mlp(hidden_states) + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + if self.return_only_hidden_states: + return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits and router_logits is not None: + outputs += (router_logits,) + + return outputs + + def _gemma2_window_mask( + self, + attention_mask: torch.Tensor | None, + hidden_states: torch.Tensor, + past_key_value: VariableCache | None, + ) -> torch.Tensor | None: + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + # Flash-attn is a 2D tensor + if self.config._attn_implementation == "flash_attention_2": + if past_key_value is not None: # when decoding + attention_mask = attention_mask[:, -self.sliding_window :] + else: + min_dtype = torch.finfo(hidden_states.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + if attention_mask.shape[-1] <= 1: # when decoding + attention_mask = attention_mask[:, :, :, -self.sliding_window :] + return attention_mask + + def _unshifted_sink_mask( + self, + attention_mask: torch.Tensor, + hidden_states: torch.Tensor, + window_length: int, + num_sink_tokens: int | None, + ) -> torch.Tensor: + assert self.config._attn_implementation == "eager", ( + "Unshifted sink is only supported in 'eager' mode." + ) + assert attention_mask is not None, "The attention mask seems to not be prepared" + + attention_mask = attention_mask.clone() + min_dtype = torch.finfo(hidden_states.dtype).min + + if window_length == 0: + attention_mask = torch.full_like(attention_mask, fill_value=min_dtype) + else: + query_length = attention_mask.shape[-2] + is_decode = query_length == 1 + if is_decode: + attention_mask[:, :, :, :-window_length] = min_dtype + else: + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-window_length + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + + attention_mask[:, :, :, :num_sink_tokens] = 0 + return attention_mask + + def _llama3_update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is + # 2D and of dynamic length even when the static KV cache is used. This is an issue for + # torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic + # shapes. (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. + # A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`. + # See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache" + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not using_static_cache + and not output_attentions + ): + if ( + AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ) + and not self.is_sliding + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @torch.compiler.disable(recursive=False) # the operations in this method are not compilable + def _llama4_update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache | None, + output_attentions: bool = False, + chunked_attention_mask=None, + use_cache=True, + ): + attn_implementation = self.config.llama4_attn_implementation + + if attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return ( + attention_mask, + attention_mask, + ) # flash does not support chunked attn TODO support flash + return None, None + + if attn_implementation not in ["sdpa", "flex_attention", "eager"]: + return None, None + + sequence_length = input_tensor.shape[1] + cache_position = cache_position.to(self.device) + attention_chunk_size = self.attention_config.llama4.attention_chunk_size + if attention_chunk_size is None: + # let the function build some chunked mask, we won't use it since it's not a chunked + # attention layer. We still need to know the chunk size for this if statement that + # comes later on: if attn_implementation == "sdpa" and chunked_attention_mask is not None + # otherwise the mask dtype is wrong for sdpa :bufo-wat: + attention_chunk_size = self.config.get_min_attention_chunk_size() + if attention_chunk_size is None: + logger.warning_once( + "Could not infer attention_chunk_size since the model (or the model shard) " + "has no chunked attention, using 8192 as default for mask construction" + ) + attention_chunk_size = 8192 + + first_cache_position = cache_position[0] + + if past_key_values is not None: + full_cache_length = past_key_values.get_max_cache_shape() or sequence_length + else: + full_cache_length = ( + attention_mask.shape[-1] if attention_mask is not None else sequence_length + ) + + cond1 = first_cache_position >= attention_chunk_size + cond2 = (first_cache_position < attention_chunk_size) & ( + first_cache_position + sequence_length > attention_chunk_size + ) + key_length = ( + torch.where( + cond1, + attention_chunk_size + sequence_length - 1, + torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), + ) + if use_cache + else full_cache_length + ) + + if attn_implementation == "flex_attention": + raise NotImplementedError("DeciLM Llama4 does not support flex attention") + # if isinstance(attention_mask, torch.Tensor): + # offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0)) + # chunked_attention_mask = make_flex_block_causal_mask( + # attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets + # ) + # attention_mask = make_flex_block_causal_mask( + # attention_mask, + # query_length=sequence_length, + # key_length=full_cache_length, + # offsets=(first_cache_position, 0), + # ) + # return attention_mask, chunked_attention_mask + # if isinstance(attention_mask, BlockMask): + # return attention_mask, chunked_attention_mask + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + dtype, device = input_tensor.dtype, input_tensor.device + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=max(full_cache_length, attention_chunk_size), + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + min_dtype=torch.finfo(dtype).min, + ) + if full_cache_length > attention_chunk_size: + start_idx = max(first_cache_position - attention_chunk_size + 1, 0) + end_idx = start_idx + key_length + chunked_attention_mask = self.create_chunked_attention_mask( + attention_chunk_size, + start=start_idx, # same offset as with flex + end=end_idx, + device=device, + ) + + ### Deci: we added this code to patch a bug in transformers + if attention_mask is None: + if past_key_values is not None: + raise NotImplementedError("We only support attention_mask=None is prefill") + attention_mask = torch.ones( + input_tensor.shape[0], input_tensor.shape[1], device=device, dtype=torch.long + ) + + local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well + # It may be smaller than attention_chunk_size -> pad it + requires_padding = local_attention_mask.shape[-1] < attention_chunk_size + if requires_padding: + local_attention_mask = nn.functional.pad( + local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) + ) + # Depending on the padding, take the query tokens from the end or the cache_position + if not requires_padding: + chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] + else: + chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] + + chunked_attention_mask = chunked_attention_mask.expand( + input_tensor.shape[0], -1, -1, -1 + ) + chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] + if attn_implementation == "eager": + min_dtype = torch.finfo(dtype).min + chunked_attention_mask = torch.where( + chunked_attention_mask == 0, min_dtype, 0.0 + ).to(dtype) + + # print(f"{output_attentions=}") + + if ( + attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and attention_mask.ndim == 4 + and not output_attentions # Only unmask for 4d masks + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if attn_implementation == "sdpa" and chunked_attention_mask is not None: + chunked_attention_mask = chunked_attention_mask.bool() + causal_mask = causal_mask.bool() + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=first_cache_position, + is_training=self.training, + ): + causal_mask = None + return causal_mask, chunked_attention_mask + + def create_chunked_attention_mask( + self, attention_chunk_size: int, start: int, end: int, device: torch.device + ) -> torch.Tensor: + """ + Generate the following: + + 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | + '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | + '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | + 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | + '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | + '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ | + + If the chunk size is 3. + This can just be appplied over the already created attention mask + """ + arange_vector = torch.arange(start, end, device=device) + block_pos = torch.abs( + arange_vector.unsqueeze(0) // attention_chunk_size + - arange_vector.unsqueeze(1) // attention_chunk_size + ) + token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) + mask = (block_pos == 0) & (token_pos <= 0) + return mask.to(device) + + +class DeciLMMultiDecoderLayer(nn.Module): + def __init__(self, config: DeciLMConfig, layer_idx: int): + super().__init__() + self.config = config + block_config = config.block_configs[layer_idx] + assert block_config.parallel_blocks is not None + num_parallel_blocks = len(block_config.parallel_blocks) + self.parallel_blocks = nn.ModuleList( + [ + DeciLMDecoderLayer(config, (layer_idx, internal_block_idx)) + for internal_block_idx in range(num_parallel_blocks) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + *args, + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + block_outputs = [block(hidden_states, *args, **kwargs) for block in self.parallel_blocks] + output_hidden_states = [ + out[0].to(hidden_states.device) + if isinstance(out, tuple) + else out.to(hidden_states.device) + for out in block_outputs + ] + output_hidden_states = torch.stack(output_hidden_states, dim=0).sum(dim=0) + output_hidden_states = ( + output_hidden_states - (len(self.parallel_blocks) - 1) * hidden_states + ) + + if self.config.block_return_only_hidden_states: + return output_hidden_states + + other_outputs = block_outputs[0][1:] + outputs = (output_hidden_states, *other_outputs) + return outputs + + +DECILM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeciLMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeciLM Model outputting raw hidden-states without any specific head on top.", + DECILM_START_DOCSTRING, +) +class DeciLMPreTrainedModel(PreTrainedModel): + config_class = DeciLMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeciLMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True # all the _supports_... flags refer to the Llama3 layers + _supports_sdpa = False + _supports_flex_attn = False + _supports_cache_class = True + _supports_quantized_cache = False + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _prepare_generation_config( + self, + generation_config: GenerationConfig | None, + *args, + **kwargs, + ) -> tuple[GenerationConfig, dict]: + try: + from transformers import cache_utils + from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING + + need_setup_cache_classes_mapping = NEED_SETUP_CACHE_CLASSES_MAPPING + except Exception: + # older releases exposed it via generation.utils + need_setup_cache_classes_mapping = {} + + # DeciLM-specific code + generation_config, model_kwargs = super()._prepare_generation_config( + generation_config, *args, **kwargs + ) + # New transformers version, can reach only through cache_utils + if need_setup_cache_classes_mapping == {}: + cache_utils._CACHE_IMPLEMENTATION_MAPPING["variable"] = VariableCache + else: + need_setup_cache_classes_mapping["variable"] = VariableCache + + generation_config.cache_implementation = "variable" + return generation_config, model_kwargs + + +DECILM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`VariableCache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + If passed to the forward function, past_key_values must be a VariableCache object (see imports). + For generation purposes, this is already handled inside model.generate(). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare DeciLM Model outputting raw hidden-states without any specific head on top.", + DECILM_START_DOCSTRING, +) +class DeciLMModel(DeciLMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`] + + Args: + config: DeciLMConfig + """ + + def __init__(self, config: DeciLMConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + ( + DeciLMDecoderLayer(config, layer_idx) + if (config.block_configs[layer_idx].parallel_blocks is None) + else DeciLMMultiDecoderLayer(config, layer_idx) + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.config.position_embedding_type in ["rope", "rope_llama4", "mistral_yarn"]: + self.rotary_emb = rope_type_to_class[self.config.position_embedding_type](config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_final_layer_norm(self): + return self.norm + + def set_final_layer_norm(self, value): + self.norm = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + output_router_logits: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | BaseModelOutputWithPast: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + is_legacy_cache_format = (past_key_values is not None) and type( + past_key_values + ).__name__ != "VariableCache" + # We use the __name__ instead of isinstance to support weird use cases + # (init cache from a checkpoint dir and use it with local code) + if is_legacy_cache_format: + raise NotImplementedError( + "DeciLMModel does not support legacy cache format, please use a newer " + "transformers version or use VariableCache explicitly (see import in this file)." + ) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + # use default device + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = None + if hasattr(self, "rotary_emb"): + # rotary emb is created all devices, so we need to move position_ids to the correct device + some_param = next(self.parameters()) + position_ids = position_ids.to(some_param.device) + cache_position = cache_position.to(some_param.device) + faux_hidden_states = position_ids.to(some_param.dtype) + position_embeddings = self.rotary_emb(faux_hidden_states, position_ids) + # print(f'START {position_embeddings.device=}') # HF hook will change the device + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + if self.config.block_return_only_hidden_states: + hidden_states = layer_outputs + next_decoder_cache = past_key_values + + else: + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Extract router logits if they exist + if output_router_logits: + router_logits_index = -1 # Router logits are always the last element + if len(layer_outputs) > (2 if output_attentions else 1) + ( + 1 if use_cache else 0 + ): + all_router_logits += (layer_outputs[router_logits_index],) + + # Final layer norm + hidden_states = hidden_states.to(next(self.parameters()).device) + hidden_states = self.norm(hidden_states) + + # Add the last hidden state + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Set the next cache + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + outputs = (hidden_states, next_cache, all_hidden_states, all_self_attns) + if output_router_logits: + outputs += (all_router_logits,) + return outputs + + # Handle different return types based on whether router logits are requested + if output_router_logits and all_router_logits: + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + else: + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + """ + The DeciLM Model transformer with a sequence classification head on top (linear layer). + + [`DeciLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DECILM_START_DOCSTRING, +) +class DeciLMForSequenceClassification(DeciLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeciLMModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + elif input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits, *transformer_outputs[1:]) + return (loss, *output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The DeciLM Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DECILM_START_DOCSTRING, +) +class DeciLMForQuestionAnswering(DeciLMPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->DeciLM + def __init__(self, config): + super().__init__(config) + self.transformer = DeciLMModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + start_positions: torch.LongTensor | None = None, + end_positions: torch.LongTensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | QuestionAnsweringModelOutput: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits, *outputs[2:]) + return (total_loss, *output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The DeciLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + DECILM_START_DOCSTRING, +) +class DeciLMForTokenClassification(DeciLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeciLMModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, *outputs[2:]) + return (loss, *output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +######################################################################## +# DeciLM-specific code +######################################################################## + + +def _find_multiple(n: int, k: int) -> int: + # DeciLM-specific code + if n % k == 0: + return n + return n + k - (n % k) + + +class DeciLMMoe(nn.Module): + """ + Implementation of Mixture of Experts module for DeciLM. + Equivalent to Llama4 MoE but implemented more frugally. + """ + + def __init__(self, config: DeciLMConfig, ffn_config: FFNConfig): + super().__init__() + self.config = config + self.ffn_config = ffn_config + + # MoE parameters + assert ffn_config.moe is not None, "MoE configuration must be provided to use DeciLMMoe" + self.moe_config: MoEConfig = ffn_config.moe + self.hidden_dim = config.hidden_size + self.num_experts_per_tok = self.moe_config.num_experts_per_tok + self.num_local_experts = self.moe_config.num_local_experts + self.expert_intermediate_dim = self.moe_config.expert_intermediate_dim + self.shared_expert_intermediate_dim = self.moe_config.shared_expert_intermediate_dim + + # Initialize experts and router + routed_expert_ffn_config = FFNConfig( + intermediate_size=self.expert_intermediate_dim, + ) + + self.experts = nn.ModuleList( + [ + DeciLMGatedMLP(config, routed_expert_ffn_config) + for _ in range(self.num_local_experts) + ] + ) + + self.router = nn.Linear(config.hidden_size, self.num_local_experts, bias=False) + + # Initialize shared expert as a standard MLP + shared_expert_ffn_config = FFNConfig( + intermediate_size=self.moe_config.shared_expert_intermediate_dim + ) + self.shared_expert = DeciLMGatedMLP(config, shared_expert_ffn_config) + + if ffn_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass through the MoE layer. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch, seq_len, hidden_dim) + + Returns: + tuple: + - torch.Tensor: Output tensor of shape (batch, seq_len, hidden_dim) + - torch.Tensor: Router scores for loss computation + """ + router_logits = self.router(hidden_states) + + routed_out = self.forward_routed_experts(hidden_states, router_logits) + + shared_out = self.shared_expert(hidden_states) + + moe_out = routed_out + shared_out + + return moe_out, router_logits + + def forward_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor + ) -> torch.Tensor: + """ + For each expert: + 1. Build the input to the expert based on the router mask + 2. Run the expert + 3. Add the result of the expert into the total MoE result using += + """ + router_top_values, router_indices = torch.topk( + router_logits, self.num_experts_per_tok, dim=-1 + ) + router_scores = torch.sigmoid(router_top_values.float()).to(hidden_states.dtype) + + routed_out = torch.zeros_like(hidden_states) + for i_expert in range(self.num_local_experts): + expert_mask = router_indices == i_expert + if expert_mask.any(): + is_token_routed_to_this_expert = expert_mask.any(dim=-1) + relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] + relevant_scores = router_scores[expert_mask] + expert_in = relevant_hidden_states * relevant_scores.unsqueeze(-1) + + expert_out = self.experts[i_expert](expert_in).to(hidden_states.device) + + routed_out[is_token_routed_to_this_expert, :] += expert_out + + return routed_out + + def extra_repr(self) -> str: + return ( + f"(MoE): num_local_experts={self.num_local_experts}, " + f"expert_intermediate_dim={self.expert_intermediate_dim}," + ) + + +class DeciLMLinearMLP(nn.Module): + # DeciLM-specific code + def __init__( + self, + config: DeciLMConfig, + ): + super().__init__() + self.linear_mlp = nn.Linear( + in_features=config.hidden_size, out_features=config.hidden_size, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_mlp.forward(x) + + +class DeciLMLinearAttention(nn.Module): + # DeciLM-specific code + def __init__( + self, + config: DeciLMConfig, + ): + super().__init__() + self.linear_attn = nn.Linear( + in_features=config.hidden_size, out_features=config.hidden_size, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_attn.forward(x) + + +def sparsity_backward_hook(*args, **kwargs): + raise NotImplementedError( + "No support for sparsity when training HF DeciLM (inference is ok though)" + ) + + +class DeciLMMambaMixer(nn.Module): + def __init__( + self, + config: DeciLMConfig, + mamba_config: MambaConfig, + ): + super().__init__() + self.mamba_mixer = MambaMixerMegatron( + d_model=config.hidden_size, + d_state=mamba_config.state_dim, + nheads=mamba_config.num_heads, + headdim=mamba_config.head_dim, + ngroups=mamba_config.num_groups, + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + x = x.permute([1, 0, 2]) # MambaMixerMegatron expects [Sequence, Batch, Embedding] + out = self.mamba_mixer(x) + out = out.permute([1, 0, 2]) # go back to [Batch, Sequence, Embedding] + return out + + +class LMHead(nn.Linear): + """ + Special class to allow FSDP wrapping without affecting other Linear layers in the model. + """ + + +class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: DeciLMConfig): + super().__init__(config) + self.model = DeciLMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_router_aux_loss(self, router_logits): + """ + Computes the auxiliary loss for router logits. + This encourages load balancing across experts. + + Args: + router_logits: List of router logits tensors from each MoE layer + Each tensor has shape [batch_size, sequence_length, num_experts] + + Returns: + Auxiliary loss tensor + """ + aux_loss = torch.tensor(0.0, device=router_logits[0].device) + + for layer_idx, layer_router_logits in enumerate(router_logits): + router_probs = torch.softmax(layer_router_logits, dim=-1) + + # Mean routing probability across batch and sequence dimensions + mean_prob = router_probs.mean(dim=[0, 1]) + + # Compute auxiliary loss: combination of load balancing and importance loss + # Load balancing loss: variance of expert usage probabilities (should be uniform) + num_experts = mean_prob.size(0) + ideal_prob = 1.0 / num_experts + balance_loss = torch.sum((mean_prob - ideal_prob) ** 2) + + # Add this layer's auxiliary loss to the total + aux_loss = aux_loss + balance_loss + + # Average over all layers + if len(router_logits) > 0: + aux_loss = aux_loss / len(router_logits) + + return aux_loss + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + output_router_logits: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Return: + """ + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + # Extract model outputs based on return type + if isinstance(outputs, MoeModelOutputWithPast): + hidden_states = outputs.last_hidden_state + router_logits = outputs.router_logits + elif return_dict: + hidden_states = outputs.last_hidden_state + router_logits = None # No router logits in this case + else: + hidden_states = outputs[0] + router_logits = outputs[4] if output_router_logits and len(outputs) > 4 else None + + # Generate logits + logits = self.lm_head(hidden_states) + logits = logits.float() + + # Calculate loss if labels are provided + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + # Calculate router aux loss if router logits are present + if router_logits is not None and self.config.router_aux_loss_coef > 0: + aux_loss = self.compute_router_aux_loss(router_logits) + loss = loss + aux_loss * self.config.router_aux_loss_coef + + # Handle non-dict return + if not return_dict: + output = (logits,) + if isinstance(outputs, tuple): + output += outputs[1:] # Add all other outputs + return (loss, *output) if loss is not None else output + + # Different output types for MoE vs regular model + if router_logits is not None: + return MoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values if return_dict else outputs[1], + hidden_states=outputs.hidden_states + if return_dict + else outputs[2] + if output_hidden_states + else None, + attentions=outputs.attentions + if return_dict + else outputs[3] + if output_attentions + else None, + router_logits=router_logits, + ) + else: + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values if return_dict else outputs[1], + hidden_states=outputs.hidden_states + if return_dict + else outputs[2] + if output_hidden_states + else None, + attentions=outputs.attentions + if return_dict + else outputs[3] + if output_attentions + else None, + ) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py new file mode 100644 index 0000000000..14c840b8b1 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Only needed for DeciLM models that use Megatron tokenizers. +DeciLM models that use Llama tokenizers do not need external code. +""" + +import json +import os +from pathlib import Path +from typing import Literal + +from transformers import PreTrainedTokenizer +from transformers.dynamic_module_utils import custom_object_save +from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE, AddedToken + +from .megatron_lm__megatron_tokenizer import ( + MegatronTokenizer, # fake import to make AutoTokenizer infer the dependency +) +from .megatron_lm__tokenizer import PATTERN_TIKTOKEN, PATTERN_TIKTOKEN_V2, CustomTikTokenizer + +MegatronTokenizer # make sure that auto-formatting doesn't remove the import + + +class MegatronTikTokenizer(PreTrainedTokenizer): + vocab_files_names: dict[str, str] = {"vocab_file": "tiktoken_vocab.json"} + model_input_names: list[str] = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file: str, + tiktoken_pattern: Literal["v1", "v2"], + vocab_size: int, + tiktoken_num_special_tokens: int, + tiktoken_special_tokens: list[str] | None, + add_bos_token: bool = False, # nm5 does not use bos token + add_eos_token: bool = False, # nm5 does not use eos token + **unused_kwargs, + ): + assert "chat_template" not in unused_kwargs, ( + "We enforce the Nemotron5 chat template from the code, " + "please do not provide a chat_template in the tokenizer_config.json file" + ) + + pattern = PATTERN_TIKTOKEN if tiktoken_pattern == "v1" else PATTERN_TIKTOKEN_V2 + self._tokenizer = CustomTikTokenizer( + path=vocab_file, + pattern=pattern, + vocab_size=vocab_size, + num_special_tokens=tiktoken_num_special_tokens, + special_tokens=tiktoken_special_tokens, + ) + + eos_token = self._tokenizer.detokenize([self._tokenizer.eos]) + bos_token = self._tokenizer.detokenize([self._tokenizer.bos]) + self.vocab = self._tokenizer.vocab + super().__init__( + eos_token=AddedToken(eos_token, normalized=False, special=True), + bos_token=AddedToken(bos_token, normalized=False, special=True), + pad_token=AddedToken(eos_token, normalized=False, special=True), + ) + + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.chat_template = NEMOTRON5_CHAT_TEMPLATE + + self._vocab_file_contents = Path(vocab_file).read_text() + self._tokenizer_config = { + "tiktoken_pattern": tiktoken_pattern, + "vocab_size": vocab_size, + "tiktoken_num_special_tokens": tiktoken_num_special_tokens, + "tiktoken_special_tokens": tiktoken_special_tokens, + "add_bos_token": add_bos_token, + "add_eos_token": add_eos_token, + "tokenizer_class": "MegatronTikTokenizer", + "auto_map": { + "AutoTokenizer": ["tokenization_decilm.MegatronTikTokenizer", None], + }, + } + + def get_vocab(self) -> dict[str, int]: + """to satisfy PreTrainedTokenizer.__init__()""" + return self.vocab + + def tokenize(self, text: str, **kwargs) -> list[str]: + return [text] + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + is_single_token = isinstance(tokens, str) + if is_single_token: + text = tokens + else: + assert len(tokens) == 1 + text = tokens[0] + + ids = self._tokenizer._model.encode(text, allowed_special="all") + + if is_single_token: + assert len(ids) == 1, ( + f"Asked to convert a single token to its id, but it's not a single token: encode('{tokens}') = {ids}" + ) + return ids[0] + else: + return ids + + def convert_ids_to_tokens( + self, ids: int | list[int], skip_special_tokens: bool = False + ) -> str | list[str]: + is_single_id = isinstance(ids, int) + if is_single_id: + ids = [ids] + + if skip_special_tokens: + ids = [idd for idd in ids if idd not in (self.eos_token_id, self.bos_token_id)] + + text = self._tokenizer.detokenize(ids) + + if is_single_id: + return text + else: + return [text] + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """Taken from LlamaTokenizer""" + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def save_pretrained( + self, + save_directory: str | os.PathLike, + legacy_format: bool | None = None, + filename_prefix: str | None = None, + push_to_hub: bool = False, + **kwargs, + ) -> tuple[str, ...]: + assert legacy_format is None, "Unsupported" + assert filename_prefix is None, "Unsupported" + assert not push_to_hub, "Unsupported" + + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + tokenizer_config_path = save_directory / TOKENIZER_CONFIG_FILE + tokenizer_config_path.write_text(json.dumps(self._tokenizer_config, indent=2)) + + vocab_files_name = self.vocab_files_names["vocab_file"] + vocab_file_path = save_directory / vocab_files_name + vocab_file_path.write_text(self._vocab_file_contents) + + custom_object_save(self, save_directory) + + return str(tokenizer_config_path), str(vocab_file_path) + + +NEMOTRON5_CHAT_TEMPLATE = """{% if messages[0].role != "system" %} + {% set messages = [{"role": "system", "content": ""}] + messages %} +{% endif %} +{% for message in messages %} + {% if message.role == "system" %} +System +{{ message.content }} + {% elif message.role == "user" %} +User +{{ message.content }} + {% elif message.role == "assistant" %} +Assistant +{{ message.content }} + {% endif %} +{% endfor %} +{% if add_generation_prompt %} +Assistant +{% else %} + +{% endif %}""" diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py new file mode 100644 index 0000000000..e67674a092 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py @@ -0,0 +1,374 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Based on https://github.com/vllm-project/vllm/blob/739e03b3449a7f3b0a81ebc30b9555305d914e2d/vllm/transformers_utils/tokenizers/mistral.py +# mypy: ignore-errors + +import os +import re +import sys +from pathlib import Path +from shutil import copyfile +from typing import TYPE_CHECKING, Any + +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + +if TYPE_CHECKING: + from mistral_common.protocol.instruct.request import ChatCompletionRequest + +logger = logging.get_logger(__name__) + + +def _called_from_vllm() -> bool: + frame = sys._getframe(1) + while frame: + mod = frame.f_globals.get("__name__", "") + if mod == "vllm" or mod.startswith("vllm."): + return True + frame = frame.f_back + return False + + +class HFAdaptedMistralTokenizer(PreTrainedTokenizer): + """ + In order to save the tokenizer, do the following: + ``` + # from import HFAdaptedMistralTokenizer + # from mistral_common.tokens.tokenizers.base import SpecialTokens + HFAdaptedMistralTokenizer.register_for_auto_class("AutoTokenizer") + tokenizer = HFAdaptedMistralTokenizer("", chat_template="dummy") + tokenizer.add_special_tokens( + {"additional_special_tokens": [v.value for _, v in SpecialTokens.__members__.items()]} + ) + tokenizer.save_pretrained("") + ``` + """ + + vocab_files_names = {"path_indicator": "tokenizer_config.json"} + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + path_indicator: str, + unk_token: str | None = None, + bos_token: str | None = None, + eos_token: str | None = None, + pad_token: str | None = None, + add_bos_token: bool = True, + add_eos_token: bool = False, + clean_up_tokenization_spaces: bool = False, + **kwargs, + ): + path_indicator: Path = Path(path_indicator) + if path_indicator.name == "tokenizer_config.json": + path_indicator = path_indicator.parent + if path_indicator.is_dir(): + tokenizer_file_name = _find_tokenizer_file(os.listdir(path_indicator)) + tokenizer_file = str(path_indicator / tokenizer_file_name) + else: + tokenizer_file = path_indicator + self._mistral_tokenizer_path = str(tokenizer_file) + + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer as MistralTokenizer + + self._mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file) + self._instruct_tokenizer = self._mistral_tokenizer.instruct_tokenizer + + # Copied from https://github.com/patrickvonplaten/vllm/blob/6cca3d8c330e169bbf386561c441ca5f3879cf85/vllm/transformers_utils/tokenizers/mistral.py + self.version: int = int( + self._instruct_tokenizer.tokenizer.version.value.split("v")[-1].split("m")[0] + ) + + tokenizer_ = self._instruct_tokenizer.tokenizer + from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer + + self.is_tekken = isinstance(tokenizer_, Tekkenizer) + from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer + + self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) + if self.is_tekken: + # Make sure special tokens will not raise + tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE + elif self.is_spm: + pass + else: + raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") + + self._vocab = tokenizer_.vocab() + # Convert to a Dict[str, int] to match protocol, but this is a lossy + # conversion. There may be multiple token ids that decode to the same + # string due to partial UTF-8 byte sequences being converted to � + self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} + self._tokenizer = tokenizer_ + self._max_token_id = self.vocab_size - 1 + self.vocab = self._vocab_dict + + bos_token = ( + bos_token + if bos_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.bos_id], + normalized=False, + special=True, + ) + ) + eos_token = ( + eos_token + if eos_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.eos_id], + normalized=False, + special=True, + ) + ) + unk_token = ( + unk_token + if unk_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.unk_id], + normalized=False, + special=True, + ) + ) + pad_token = ( + pad_token + if pad_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.pad_id], + normalized=False, + special=True, + ) + ) + + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + + self._in_vllm = _called_from_vllm() + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def vocab_size(self): + """Returns vocab size""" + return self._tokenizer.n_words + + def get_vocab(self): + """Returns vocab as a dict""" + return self._vocab_dict + + def tokenize( + self, + text: str, + pair: str | None = None, + add_special_tokens: bool | None = None, + **kwargs, + ) -> list[str]: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + if add_special_tokens is None: + bos = self._add_bos_token + eos = self._add_eos_token + else: + bos = add_special_tokens + eos = add_special_tokens + + input_ids = [] + parts = self.tokens_trie.split(text) + + in_vllm_chat_completion_mode = False + if ( + self._in_vllm + and len(parts) > 1 + and parts[0] == SpecialTokens.bos.value + and parts[1] == SpecialTokens.begin_inst.value + ): + # This is a dangerous hack to make the tokenizer work with vLLM. + # It means we are in chat completion mode. + bos = False + eos = False + in_vllm_chat_completion_mode = True + + if os.environ.get("HF_TOKENIZE_FORCE_NO_SPECIAL_TOKENS", "0") == "1": + bos = False + eos = False + + if not self._in_vllm or in_vllm_chat_completion_mode: + for part in parts: + if part in self.additional_special_tokens and part in self._vocab_dict: + input_ids.append(self._convert_token_to_id(part)) + else: + input_ids.extend(self._tokenizer.encode(part, bos=bos, eos=eos)) + else: + # Doesn't tokenize special tokens properly, but this is the behavior of vLLM when we are in completion mode. + input_ids = self._tokenizer.encode(text, bos=bos, eos=eos) + + if os.environ.get("HF_TOKENIZE_ABUSE", "1") == "1": + # A lot faster than the other option + return input_ids + else: + return [self._convert_id_to_token(token_id) for token_id in input_ids] + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + if len(tokens) > 0 and isinstance(tokens[0], int): + return tokens + return super().convert_tokens_to_ids(tokens) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self._vocab_dict[token] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + piece = self._tokenizer.id_to_piece(index) + return piece if isinstance(piece, str) else piece.value + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + if self.is_tekken: + tokens = [ + t + for t in tokens + if (t is SpecialTokens.tool_calls or t not in self._tokenizer._all_special_tokens) + ] + + if any(isinstance(t, bytes) for t in tokens): + # we need to encode and decode all tokens again + shift = self._tokenizer.num_special_tokens + + def _token_to_id(t: str): + t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t + try: + return shift + self._tokenizer._tekken_token2id_nospecial[t_bytes] + except KeyError: + logger.warning( + "Failed to convert token %s to id, replacing with ", + t_bytes, + ) + return self._tokenizer.unk_id + + ids = [_token_to_id(t) for t in tokens] + decoded = self._tokenizer.decode(ids) + else: + decoded = "".join(tokens) + else: + # make sure certain special tokens like Tool calls are + # not decoded + special_tokens = {SpecialTokens.tool_calls} + regular_tokens: list[str] = [] + decoded_list = [] + + for token in tokens: + if token in special_tokens: + if regular_tokens: + decoded_list.append(self._tokenizer.decode(regular_tokens)) + regular_tokens = [] + decoded_list.append(token) + else: + regular_tokens.append(token) + + if regular_tokens: + decoded_list.append(self._tokenizer.decode(regular_tokens)) # type: ignore[no-untyped-call] + + decoded = "".join(decoded_list) + + return decoded + + def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]: + """ + Use this method to save the full tokenizer file. + """ + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join(save_directory, "tekken.json") + + if os.path.abspath(self._mistral_tokenizer_path) != os.path.abspath(out_vocab_file): + copyfile(self._mistral_tokenizer_path, out_vocab_file) + + return (out_vocab_file,) + + def apply_chat_template( + self, + conversation: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + tokenize: bool = True, + **kwargs, + ) -> list[int]: + request = _make_mistral_chat_completion_request(conversation, tools) + encoded = self._mistral_tokenizer.encode_chat_completion(request) + if tokenize: + # encode-decode to get clean prompt + return encoded.tokens + else: + return encoded.text + + +def _find_tokenizer_file(files: list[str]): + file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") + + matched_files = [file for file in files if file_pattern.match(file)] + if len(matched_files) > 1: + raise OSError( + f"Found {len(matched_files)} files matching the " + f"pattern: `{file_pattern.pattern}`. Make sure only one Mistral " + f"tokenizer is present in {files}." + ) + elif len(matched_files) == 0: + raise OSError( + f"Found {len(matched_files)} files matching the " + f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " + f"tokenizer is present in {files}." + ) + + return matched_files[0] + + +def _make_mistral_chat_completion_request( + messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None +) -> "ChatCompletionRequest": + last_message = messages[-1] + if last_message["role"] == "assistant": + last_message["prefix"] = True + + # mistral-common requires AssistantMessage content to be string [1]. + # + # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 + for message in messages: + if message.get("role") == "assistant": + content = message.get("content") + if isinstance(content, list): + content = "\n".join(chunk.get("text") for chunk in content) + message["content"] = content + + # The Mistral client, in comparison to the OpenAI client, requires the + # "parameters" dict to be present, even if it's empty. + if tools: + for function in [tool["function"] for tool in tools if tool["type"] == "function"]: + if function.get("parameters") is None: + function["parameters"] = {} + + from mistral_common.protocol.instruct.request import ChatCompletionRequest + + return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py new file mode 100644 index 0000000000..6c964dbfc1 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py @@ -0,0 +1,254 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import OrderedDict + +import torch +from packaging import version +from torch import Tensor, nn + +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class PytorchGELUTanh(nn.Module): + """ + A fast C implementation of the tanh approximation of the GeLU activation function. See + https://arxiv.org/abs/1606.08415. + + This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical + match due to rounding errors. + """ + + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.12.0"): + raise ImportError( + f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " + "PytorchGELUTanh. Please upgrade torch." + ) + + def forward(self, input: Tensor) -> Tensor: + return nn.functional.gelu(input, approximate="tanh") + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class GELUActivation(nn.Module): + """ + Original Implementation of the GELU activation function in Google BERT repo when initially created. For + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, use_gelu_python: bool = False): + super().__init__() + if use_gelu_python: + self.act = self._gelu_python + else: + self.act = nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class FastGELUActivation(nn.Module): + """ + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + + +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Module): + """ + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://arxiv.org/abs/2004.09602. + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, min: float, max: float): + if min > max: + raise ValueError(f"min should be < max (got min: {min}, max: {max})") + + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + return torch.clip(gelu(x), self.min, self.max) + + +class AccurateGELUActivation(nn.Module): + """ + Applies GELU approximation that is faster than default and more accurate than QuickGELU. See: + https://github.com/hendrycks/GELUs + + Implemented along with MEGA (Moving Average Equipped Gated Attention) + """ + + def __init__(self): + super().__init__() + self.precomputed_constant = math.sqrt(2 / math.pi) + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3)))) + + +class MishActivation(nn.Module): + """ + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also + visit the official repository for the paper: https://github.com/digantamisra98/Mish + """ + + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.9.0"): + self.act = self._mish_python + else: + self.act = nn.functional.mish + + def _mish_python(self, input: Tensor) -> Tensor: + return input * torch.tanh(nn.functional.softplus(input)) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class LinearActivation(nn.Module): + """ + Applies the linear activation function, i.e. forwarding input directly to output. + """ + + def forward(self, input: Tensor) -> Tensor: + return input + + +class LaplaceActivation(nn.Module): + """ + Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See + https://arxiv.org/abs/2209.10655 + + Inspired by squared relu, but with bounded range and gradient for better stability + """ + + def forward(self, input, mu=0.707107, sigma=0.282095): + input = (input - mu).div(sigma * math.sqrt(2.0)) + return 0.5 * (1.0 + torch.erf(input)) + + +class ReLUSquaredActivation(nn.Module): + """ + Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 + """ + + def forward(self, input): + relu_applied = nn.functional.relu(input) + squared = torch.square(relu_applied) + return squared + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "gelu": GELUActivation, + "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), + "gelu_fast": FastGELUActivation, + "gelu_new": NewGELUActivation, + "gelu_python": (GELUActivation, {"use_gelu_python": True}), + "gelu_pytorch_tanh": PytorchGELUTanh, + "gelu_accurate": AccurateGELUActivation, + "laplace": LaplaceActivation, + "leaky_relu": nn.LeakyReLU, + "linear": LinearActivation, + "mish": MishActivation, + "quick_gelu": QuickGELUActivation, + "relu": nn.ReLU, + "relu2": ReLUSquaredActivation, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": nn.SiLU, + "swish": nn.SiLU, + "tanh": nn.Tanh, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +def get_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation("gelu_python") +gelu_new = get_activation("gelu_new") +gelu = get_activation("gelu") +gelu_fast = get_activation("gelu_fast") +quick_gelu = get_activation("quick_gelu") +silu = get_activation("silu") +mish = get_activation("mish") +linear_act = get_activation("linear") diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py new file mode 100644 index 0000000000..83d7251dda --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py @@ -0,0 +1,1447 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +import copy +import importlib.metadata +import json +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from packaging import version + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torchdynamo_compiling, logging + + +logger = logging.get_logger(__name__) + + +class Cache(torch.nn.Module): + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states, if there is any.""" + raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + + +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__(self) -> None: + super().__init__() + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( + repeats, dim=0 + ) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +class OffloadedCache(DynamicCache): + """ + A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. + Useful for generating from models with very long context. + + In addition to the default CUDA stream, where all forward() computations happen, + this class uses another stream, the prefetch stream, which it creates itself. + Since scheduling of operations on separate streams happens independently, this class uses + the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. + The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to + ensure the eviction is scheduled after all computations on that cache are finished. + """ + + def __init__(self) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("OffloadedCache can only be used with a GPU") + super().__init__() + self.original_device = [] + self.prefetch_stream = torch.cuda.Stream() + self.beam_idx = None # used to delay beam search operations + + def prefetch_layer(self, layer_idx: int): + "Starts prefetching the next layer cache" + if layer_idx < len(self): + with torch.cuda.stream(self.prefetch_stream): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to( + device, non_blocking=True + ) + + def evict_previous_layer(self, layer_idx: int): + "Moves the previous layer cache to the CPU" + if len(self) > 2: + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(self) + self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." + if layer_idx < len(self): + # Evict the previous layer if necessary + torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # Load current layer cache to its original device if not already there + original_device = self.original_device[layer_idx] + self.prefetch_stream.synchronize() + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + # Now deal with beam search ops which were delayed + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(original_device) + key_tensor = key_tensor.index_select(0, self.beam_idx) + value_tensor = value_tensor.index_select(0, self.beam_idx) + # Prefetch the next layer + self.prefetch_layer((layer_idx + 1) % len(self)) + return (key_tensor, value_tensor) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Saves the beam indices and reorders the cache when the tensor is back to its device.""" + # We delay this operation until the tensors are back to their original + # device because performing torch.index_select on the CPU is very slow + del self.beam_idx + self.beam_idx = beam_idx.clone() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.original_device.append(key_states.device) + self.evict_previous_layer(layer_idx) + else: + key_tensor, value_tensor = self[layer_idx] + self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError + # if a method is not supposed to be supported in a subclass we should set it to None + from_legacy_cache = None + + to_legacy_cache = None + + +class SinkCache(Cache): + """ + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to + generate beyond the length of its context window, without losing fluency in the conversation. As it discards past + tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Parameters: + window_length (`int`): + The length of the context window. + num_sink_tokens (`int`): + The number of sink tokens. See the original paper for more information. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__(self, window_length: int, num_sink_tokens: int) -> None: + super().__init__() + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + @staticmethod + def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] + shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] + shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_rerotation_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.window_length + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, + `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the + rotation as the tokens are shifted. + + Return: + A tuple containing the updated key and value states. + """ + # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models + # with partially rotated position embeddings, like Phi or Persimmon. + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") + partial_rotation_size = cache_kwargs.get("partial_rotation_size") + using_rope = cos is not None and sin is not None + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the sin/cos cache, which holds sin/cos values for all possible positions + if using_rope and layer_idx == 0: + # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove + # after all RoPE models have a llama-like cache utilization. + if cos.dim() == 2: + self._cos_cache = cos + self._sin_cache = sin + else: + if self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] + elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : + ] + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, + self._cos_cache[: self.window_length], + self._sin_cache[: self.window_length], + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb( + keys_to_keep, rerotation_cos, rerotation_sin + ) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : + ] + self.value_cache[layer_idx] = torch.cat( + [sink_values, values_to_keep, value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None + ) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if max_cache_len is None else max_cache_len + ) + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) + for idx in range(config.num_hidden_layers): + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + # Notes: + # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case + # it is not needed anyway) + # 2. `torch.export()` requires mutations to be registered as buffers. + if not is_torchdynamo_compiling(): + self.register_buffer( + f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device) + ) + self.register_buffer( + f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device) + ) + new_layer_key_cache = getattr(self, f"key_cache_{idx}") + new_layer_value_cache = getattr(self, f"value_cache_{idx}") + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + cache_position = cache_kwargs.get("cache_position") + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + # return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + return self._seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.max_cache_len + + def reset(self): + self._seen_tokens = 0 + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class SlidingWindowCache(StaticCache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None + ) -> None: + super().__init__(config, max_batch_size, max_cache_len, device, dtype) + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + max_cache_len = min(config.sliding_window, max_cache_len) + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) + if cache_position.shape[0] > self.max_cache_len: + k_out = key_states[:, :, -self.max_cache_len :, :] + v_out = value_states[:, :, -self.max_cache_len :, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones( + self.max_cache_len, dtype=torch.long, device=value_states.device + ).cumsum(0) + cache_position = cache_position.clamp(0, self.max_cache_len - 1) + to_shift = cache_position >= self.max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + try: + cache_position.to(device=k_out.device) + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + + return k_out, v_out + + def get_max_length(self) -> Optional[int]: + # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is + return None + + def reset(self): + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class EncoderDecoderCache(Cache): + """ + Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and + cross-attention caches. + + Example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") + >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") + + >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") + + >>> # Prepare cache classes for encoder and decoder and pass it to model's forward + >>> self_attention_cache = DynamicCache() + >>> cross_attention_cache = DynamicCache() + >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + + """ + + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): + super().__init__() + self.self_attention_cache = self_attention_cache + self.cross_attention_cache = cross_attention_cache + + self.is_updated = {} + for layer_idx in range(len(cross_attention_cache.key_cache)): + self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), + self.cross_attention_cache.to_legacy_cache(), + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.self_attention_cache.key_cache) <= layer_idx: + return 0 + return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + if hasattr(self.self_attention_cache, "reset"): + self.self_attention_cache.reset() + if hasattr(self.cross_attention_cache, "reset"): + self.cross_attention_cache.reset() + elif not hasattr(self.self_attention_cache, "reset") and not hasattr( + self.cross_attention_cache, "reset" + ): + raise ValueError( + "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " + "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " + f"Got {self.self_attention_cache.__str__()} for the self attention cache and " + f"{self.cross_attention_cache.__str__()} for the cross attention cache." + ) + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): + if not ( + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) + ): + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) + + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) + + def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) + + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + @classmethod + def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + self_attention_cache = DynamicCache() + cross_attention_cache = DynamicCache() + for idx in range(len(splits[0])): + layer_keys = torch.cat( + [current.self_attention_cache.key_cache[idx] for current in splits], dim=0 + ) + layer_values = torch.cat( + [current.self_attention_cache.value_cache[idx] for current in splits], dim=0 + ) + self_attention_cache.update(layer_keys, layer_values, idx) + + layer_keys = torch.cat( + [current.cross_attention_cache.key_cache[idx] for current in splits], dim=0 + ) + layer_values = torch.cat( + [current.cross_attention_cache.value_cache[idx] for current in splits], dim=0 + ) + cross_attention_cache.update(layer_keys, layer_values, idx) + return cls(self_attention_cache, cross_attention_cache) + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) + + +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention + and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention + and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`, *optional*, defaults to `"cpu"`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None + ) -> None: + super().__init__() + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) + self.is_sliding = torch.tensor( + [not bool(i % 2) for i in range(config.num_hidden_layers)], + dtype=torch.bool, + device=device, + ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + global_cache_shape = ( + max_batch_size, + self.num_key_value_heads, + max_cache_len, + self.head_dim, + ) + sliding_cache_shape = ( + max_batch_size, + self.num_key_value_heads, + min(config.sliding_window, max_cache_len), + self.head_dim, + ) + for i in range(config.num_hidden_layers): + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + if cache_position.shape[0] > max_cache_len: + k_out = key_states[:, :, -max_cache_len:, :] + v_out = value_states[:, :, -max_cache_len:, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, max_cache_len - 1) + to_shift = cache_position >= max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % max_cache_len + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + return k_out, v_out + + def _static_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + sliding_window = cache_kwargs.get("sliding_window") + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + if sliding_window: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_length(self) -> Optional[int]: + # in theory there is no limit because the sliding window size is fixed + # no matter how long the sentence is + return self.max_cache_len + + def get_seq_length(self, layer_idx: Optional[int] = 0): + return None + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + dtype (*optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Attributes: + dtype: (`torch.dtype`): + The default `dtype` used to initializing the cache. + intermediate_size: (`int`): + Model's intermediate_size taken from config. + ssm_state_size: (`int`): + Model's state_size taken from config. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config + conv_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states. + ssm_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv = outputs.past_key_values + ``` + """ + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Optional[str] = None, + **kwargs, + ): + self.dtype = dtype + self.max_batch_size = max_batch_size + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + self.ssm_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=dtype, + ) + + torch._dynamo.mark_static_address(self.conv_states) + torch._dynamo.mark_static_address(self.ssm_states) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py new file mode 100644 index 0000000000..461996f742 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from .transformers_4_44_2__modeling_rope_utils import rope_config_validation + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py new file mode 100644 index 0000000000..7257800678 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py @@ -0,0 +1,498 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + + if causal_4d_mask is not None: + expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) + + # expanded_attn_mask + causal_4d_mask can cause some overflow + expanded_4d_mask = expanded_attn_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + mask.masked_fill_(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def _unmask_unattended( + expanded_mask: torch.FloatTensor, + min_dtype: float, + ): + # fmt: off + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 + + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. + + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. + + For example, if `expanded_mask` is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified `expanded_mask` will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # fmt: on + if expanded_mask.dtype == torch.bool: + raise ValueError( + "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." + ) + + return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) + + @staticmethod + def _ignore_causal_mask_sdpa( + attention_mask: Optional[torch.Tensor], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, + is_training: bool = False, + ) -> bool: + """ + Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the `attention_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + + _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + key_value_length = query_length + past_key_values_length + + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + ignore_causal_mask = False + + if attention_mask is None: + # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or + # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). + # Thus, we only set `ignore_causal_mask = True` if the model is set to training. + # + # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). + if ( + (is_training or not is_tracing) + and (query_length == 1 or key_value_length == query_length) + and (sliding_window is None or key_value_length < sliding_window) + ): + ignore_causal_mask = True + elif sliding_window is None or key_value_length < sliding_window: + if len(attention_mask.shape) == 4: + return False + elif (is_training or not is_tracing) and torch.all(attention_mask == 1): + if query_length == 1 or key_value_length == query_length: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. + + return ignore_causal_mask + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. + + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if ignore_causal_mask: + expanded_4d_mask = None + elif attention_mask is None: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + else: + if attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + expanded_4d_mask = attention_mask + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + if not is_tracing and expanded_4d_mask.device.type == "cuda": + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min + ) + + return expanded_4d_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + _, key_value_length = mask.shape + tgt_len = tgt_len if tgt_len is not None else key_value_length + + is_tracing = ( + torch.jit.is_tracing() + or isinstance(mask, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. + if not is_tracing and torch.all(mask == 1): + return None + else: + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, Tuple, List], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py new file mode 100644 index 0000000000..9e9fb46ca4 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py @@ -0,0 +1,363 @@ +# coding=utf-8 +# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import inspect +import os +from typing import Optional, Tuple, Union + + +import torch +import torch.nn.functional as F + +from functools import lru_cache +import importlib.metadata +import importlib.util +from packaging import version + +from transformers.utils import is_flash_attn_2_available + + +if is_flash_attn_2_available(): + try: + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn import flash_attn_func, flash_attn_varlen_func + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + except ImportError: + raise "Unable to import flash_attn" + + +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check if the package spec exists and grab its version to avoid importing a local directory + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + if return_version: + return package_exists, package_version + else: + return package_exists + + +@lru_cache() +def is_flash_attn_greater_or_equal(library_version: str): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cummulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cummulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def _flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: bool = None, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_top_left_mask (`bool`, defaults to `False`): + flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + softcap (`float`, *optional*): + Softcap for the attention logits, used e.g. in gemma2. + deterministic (`bool`, *optional*): + Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. + """ + if not use_top_left_mask: + causal = is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. + causal = is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if is_flash_attn_greater_or_equal("2.4.1"): + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + + if softcap is not None: + flash_kwargs["softcap"] = softcap + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. + # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py new file mode 100644 index 0000000000..aa9f07b879 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py @@ -0,0 +1,1768 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from transformers.utils import ModelOutput + + +@dataclass +class BaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoECausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden + states terms, to train a MoE model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + z_loss for the sparse modules. + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + z_loss: torch.FloatTensor = None + aux_loss: torch.FloatTensor = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) with mixture of experts outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as + Mixture of Expert's router hidden states terms, to train a MoE model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key, + value states of the self-attention and the cross-attention layers if model is used in encoder-decoder + setting. Only relevant if `config.is_decoder = True`. + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts + models. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + encoder_z_loss: torch.FloatTensor = None + decoder_z_loss: torch.FloatTensor = None + encoder_aux_loss: torch.FloatTensor = None + decoder_aux_loss: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class NextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided): + Next sequence prediction (classification) loss. + logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class TokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class QuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of semantic segmentation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutput(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DepthEstimatorOutput(ModelOutput): + """ + Base class for outputs of depth estimation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): + Predicted depth for each pixel. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + predicted_depth: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageSuperResolutionOutput(ModelOutput): + """ + Base class for outputs of image super resolution models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed images, possibly upscaled. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Wav2Vec2BaseModelOutput(ModelOutput): + """ + Base class for models that have been trained with the Wav2Vec2 loss objective. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`): + Sequence of extracted feature vectors of the last convolutional layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + extract_features: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class XVectorOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ForXVector`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Classification hidden states before AMSoftmax. + embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Utterance embeddings used for vector similarity-based retrieval. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + embeddings: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BackboneOutput(ModelOutput): + """ + Base class for outputs of backbones. + + Args: + feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): + Feature maps of the stages. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`, + depending on the backbone. + + Hidden-states of the model at the output of each stage plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Only applicable if the backbone uses attention. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + feature_maps: Tuple[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndProjection(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`. + + Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + projection_state: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqSpectrogramOutput(ModelOutput): + """ + Base class for sequence-to-sequence spectrogram outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Spectrogram generation loss. + spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): + The predicted spectrogram. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + spectrogram: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqTSModelOutput(ModelOutput): + """ + Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up + sequential decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class Seq2SeqTSPredictionOutput(ModelOutput): + """ + Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the + chosen distribution. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided): + Distributional loss. + params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`): + Parameters of the chosen distribution. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. + """ + + loss: Optional[torch.FloatTensor] = None + params: Optional[Tuple[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class SampleTSPredictionOutput(ModelOutput): + """ + Base class for time series model's predictions outputs that contains the sampled values from the chosen + distribution. + + Args: + sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`): + Sampled values from the chosen distribution. + """ + + sequences: torch.FloatTensor = None + + +@dataclass +class MaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py new file mode 100644 index 0000000000..761c2b6402 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py @@ -0,0 +1,574 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_linear_scaling_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + factor = rope_kwargs["factor"] + elif config is not None: + factor = config.rope_scaling["factor"] + + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + # Then applies linear scaling to the frequencies. + # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so + # applying scaling to the inverse frequencies is equivalent. + inv_freq /= factor + return inv_freq, attention_factor + + +def _compute_dynamic_ntk_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length, used to update the dynamic RoPE at inference time. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + max_position_embeddings = rope_kwargs["max_position_embeddings"] + factor = rope_kwargs["factor"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + attention_factor = 1.0 # Unused in this type of RoPE + + # seq_len: default to max_position_embeddings, e.g. at init time + seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings + + # Compute the inverse frequencies + base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + + +def _compute_longrope_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with LongRoPE scaling. Please refer to the + [original implementation](https://github.com/microsoft/LongRoPE) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + # No need to keep BC with longrope, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " + f"{rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + long_factor = config.rope_scaling["long_factor"] + short_factor = config.rope_scaling["short_factor"] + factor = config.rope_scaling.get("factor") + attention_factor = config.rope_scaling.get("attention_factor") + + # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if hasattr(config, "original_max_position_embeddings"): + max_position_embeddings = config.original_max_position_embeddings + expanded_max_position_embeddings = config.max_position_embeddings + factor = expanded_max_position_embeddings / max_position_embeddings + else: + max_position_embeddings = config.max_position_embeddings + expanded_max_position_embeddings = max_position_embeddings * factor + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if factor <= 1.0: + attention_factor = 1.0 + else: + attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) + + # Compute the inverse frequencies -- scaled based on the target sequence length + if expanded_max_position_embeddings > max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) + inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + + return inv_freq, attention_factor + + +def _compute_llama3_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "linear": _compute_linear_scaling_rope_parameters, + "dynamic": _compute_dynamic_ntk_parameters, + "yarn": _compute_yarn_parameters, + "longrope": _compute_longrope_parameters, + "llama3": _compute_llama3_parameters, +} + + +def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's gracefully handle it + if "rope_type" not in received_keys and "type" in received_keys: + received_keys -= {"type"} + received_keys.add("rope_type") + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + +def _validate_linear_scaling_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_yarn_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +def _validate_longrope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "short_factor", "long_factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + short_factor = rope_scaling.get("short_factor") + if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): + logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") + if not len(short_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") + + long_factor = rope_scaling.get("long_factor") + if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): + logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") + if not len(long_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") + + # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over + # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is + # unique to longrope (= undesirable) + if hasattr(config, "original_max_position_embeddings"): + logger.warning_once( + "This model has set a `original_max_position_embeddings` field, to be used together with " + "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" + "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " + "as it is compatible with most model architectures." + ) + else: + factor = rope_scaling.get("factor") + if factor is None: + logger.warning("Missing required keys in `rope_scaling`: 'factor'") + elif not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + + +def _validate_llama3_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + if low_freq_factor is None or not isinstance(low_freq_factor, float): + logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") + if high_freq_factor is None or not isinstance(high_freq_factor, float): + logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") + if high_freq_factor <= low_freq_factor: + logger.warning( + "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" + f"{high_freq_factor} and low_freq_factor={low_freq_factor}" + ) + + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " + f"{original_max_position_embeddings}" + ) + if original_max_position_embeddings >= config.max_position_embeddings: + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " + f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" + ) + + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "linear": _validate_linear_scaling_rope_parameters, + "dynamic": _validate_dynamic_scaling_rope_parameters, + "yarn": _validate_yarn_parameters, + "longrope": _validate_longrope_parameters, + "llama3": _validate_llama3_parameters, +} + + +def rope_config_validation(config: PretrainedConfig): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py new file mode 100644 index 0000000000..a1b413b0e0 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py @@ -0,0 +1,32 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn + +ALL_LAYERNORM_LAYERS = [nn.LayerNorm] \ No newline at end of file diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py new file mode 100644 index 0000000000..3dac4a51c6 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py @@ -0,0 +1,2535 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import copy +import importlib.metadata +import json +import os +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any + +import torch +from packaging import version +from transformers.configuration_utils import PretrainedConfig +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6 +from transformers.utils import ( + is_hqq_available, + is_optimum_quanto_available, + is_torch_greater_or_equal, + logging, +) + +if is_hqq_available(): + from hqq.core.quantize import Quantizer as HQQQuantizer + +logger = logging.get_logger(__name__) + + +class Cache: + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + is_compileable = False + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: int | None = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + + +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: str | os.PathLike): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class QuantizedCacheConfig(CacheConfig): + """ + Configuration class for quantized cache settings. + + Attributes: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original precision. + Defaults to 128. + compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to perform computations, should be same as the model's device. + """ + + def __init__( + self, + backend: str = "quanto", + nbits: int | None = 4, + axis_key: int | None = 0, + axis_value: int | None = 0, + q_group_size: int | None = 64, + residual_length: int | None = 128, + compute_dtype: torch.dtype | None = torch.float16, + device: str | None = "cpu", + ): + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) + + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) + + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) + + +@dataclass +class StaticCacheConfig(CacheConfig): + """ + Configuration class for static cache settings. + """ + + cache_implementation = "static" + + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + self.batch_size = batch_size + self.max_cache_len = max_cache_len + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + + if self.batch_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="batch_size", + correct_value="> 0", + found_value=self.batch_size, + ), + ) + + if self.max_cache_len <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="max_cache_len", + correct_value="> 0", + found_value=self.max_cache_len, + ), + ) + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicCache() + ``` + """ + + def __init__(self, _distributed_cache_data: Iterable = None) -> None: + super().__init__() + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). + # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. + if _distributed_cache_data is not None: + for key_states, value_states in _distributed_cache_data: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + not self.key_cache[ + layer_idx + ].numel() # prefers not t.numel() to len(t) == 0 to export the model + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) + <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None + ) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: list["DynamicCache"]) -> "DynamicCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + key_cache = [ + current.key_cache[idx] for current in splits if current.key_cache[idx].numel() + ] + value_cache = [ + current.value_cache[idx] for current in splits if current.value_cache[idx].numel() + ] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( + repeats, dim=0 + ) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +# Utilities for `DynamicCache` <> torch.export support +def _flatten_dynamic_cache( + dynamic_cache: DynamicCache, +): + """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" + if not isinstance(dynamic_cache, DynamicCache): + raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") + + if not is_torch_greater_or_equal_than_2_6: + logger.warning_once( + "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." + ) + + # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking + dictionary = { + "key_cache": getattr(dynamic_cache, "key_cache"), + "value_cache": getattr(dynamic_cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten(dictionary) + + +def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): + dictionary = { + "key_cache": getattr(dynamic_cache, "key_cache"), + "value_cache": getattr(dynamic_cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) + + +def _unflatten_dynamic_cache( + values, + context: torch.utils._pytree.Context, +): + dictionary = torch.utils._pytree._dict_unflatten(values, context) + cache = DynamicCache() + for k, v in dictionary.items(): + setattr(cache, k, v) + return cache + + +def _flatten_dynamic_cache_for_fx(cache, spec): + dictionary = { + "key_cache": getattr(cache, "key_cache"), + "value_cache": getattr(cache, "value_cache"), + } + return torch.utils._pytree.tree_flatten(dictionary)[0] + + +if is_torch_greater_or_equal("2.3"): + torch.utils._pytree.register_pytree_node( + DynamicCache, + _flatten_dynamic_cache, + _unflatten_dynamic_cache, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, + ) + # TODO (tmanlaibaatar) This won't be needed in torch 2.7. + torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) + + +class OffloadedCache(DynamicCache): + """ + A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. + Useful for generating from models with very long context. + + In addition to the default accelerator stream, where all forward() computations happen, + this class uses another stream, the prefetch stream, which it creates itself. + Since scheduling of operations on separate streams happens independently, this class uses + the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. + The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to + ensure the eviction is scheduled after all computations on that cache are finished. + """ + + def __init__(self) -> None: + if not ( + torch.cuda.is_available() + or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + ): + raise RuntimeError( + "OffloadedCache can only be used with a GPU" + + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") + ) + + super().__init__() + self.original_device = [] + self.prefetch_stream = None + self.prefetch_stream = ( + torch.Stream() + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.Stream() + ) + self.beam_idx = None # used to delay beam search operations + + def prefetch_layer(self, layer_idx: int): + "Starts prefetching the next layer cache" + if layer_idx < len(self): + with ( + self.prefetch_stream + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.stream(self.prefetch_stream) + ): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to( + device, non_blocking=True + ) + + def evict_previous_layer(self, layer_idx: int): + "Moves the previous layer cache to the CPU" + if len(self) > 2: + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(self) + self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + + def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: + "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." + if layer_idx < len(self): + # Evict the previous layer if necessary + if is_torch_greater_or_equal("2.7", accept_dev=True): + torch.accelerator.current_stream().synchronize() + else: + torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # Load current layer cache to its original device if not already there + original_device = self.original_device[layer_idx] + self.prefetch_stream.synchronize() + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + # Now deal with beam search ops which were delayed + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(original_device) + key_tensor = key_tensor.index_select(0, self.beam_idx) + value_tensor = value_tensor.index_select(0, self.beam_idx) + # Prefetch the next layer + self.prefetch_layer((layer_idx + 1) % len(self)) + return (key_tensor, value_tensor) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Saves the beam indices and reorders the cache when the tensor is back to its device.""" + # We delay this operation until the tensors are back to their original + # device because performing torch.index_select on the CPU is very slow + del self.beam_idx + self.beam_idx = beam_idx.clone() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) < layer_idx: + raise ValueError( + "OffloadedCache does not support model usage where layers are skipped. Use DynamicCache." + ) + elif len(self.key_cache) == layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.original_device.append(key_states.device) + self.evict_previous_layer(layer_idx) + else: + key_tensor, value_tensor = self[layer_idx] + self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError + # if a method is not supposed to be supported in a subclass we should set it to None + from_legacy_cache = None + + to_legacy_cache = None + + +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + super().__init__() + self._quantized_key_cache: list[torch.Tensor] = [] + self._quantized_value_cache: list[torch.Tensor] = [] + + self.nbits = cache_config.nbits + self.residual_length = cache_config.residual_length + self.q_group_size = cache_config.q_group_size + self.axis_key = cache_config.axis_key + self.axis_value = cache_config.axis_value + self.compute_dtype = cache_config.compute_dtype + self.device = cache_config.device + + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if len(self.key_cache) < layer_idx: + raise ValueError( + "QuantizedCache does not support model usage where layers are skipped. Use DynamicCache." + ) + elif len(self.key_cache) == layer_idx: + self._quantized_key_cache.append( + self._quantize(key_states.contiguous(), axis=self.axis_key) + ) + self._quantized_value_cache.append( + self._quantize(value_states.contiguous(), axis=self.axis_value) + ) + self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + self.value_cache.append( + torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + ) + keys_to_return, values_to_return = key_states, value_states + else: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] + values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] + + keys_to_return = torch.cat(keys_to_return, dim=-2) + values_to_return = torch.cat(values_to_return, dim=-2) + if ( + self.key_cache[layer_idx].dim() == 4 + and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length + ): + self._quantized_key_cache[layer_idx] = self._quantize( + keys_to_return.contiguous(), axis=self.axis_key + ) + self._quantized_value_cache[layer_idx] = self._quantize( + values_to_return.contiguous(), axis=self.axis_value + ) + self.key_cache[layer_idx] = torch.zeros( + 0, dtype=key_states.dtype, device=key_states.device + ) + self.value_cache[layer_idx] = torch.zeros( + 0, dtype=key_states.dtype, device=key_states.device + ) + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return keys_to_return, values_to_return + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is + # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to verify attn_weight shape in some models + return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 + + def _quantize(self, tensor, axis): + """Quantizes a key/value using a defined quantization method.""" + raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") + + def _dequantize(self, q_tensor): + """Dequantizes back the tensor that was quantized by `self._quantize()`""" + raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") + + +class QuantoQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + + Example: + + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() + ``` + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + + if is_optimum_quanto_available(): + optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) + if optimum_quanto_version <= version.parse("0.2.5"): + raise ImportError( + f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." + ) + from optimum.quanto import MaxOptimizer, qint2, qint4 + + if self.nbits not in [2, 4]: + raise ValueError( + f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}" + ) + + if self.axis_key not in [0, -1]: + raise ValueError( + f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}" + ) + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + self.optimizer = ( + MaxOptimizer() + ) # hardcode as it's the only one for per-channel quantization + + def _quantize(self, tensor, axis): + # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight + + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor + + def _dequantize(self, qtensor): + return qtensor.dequantize() + + +class HQQQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + + Example: + + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() + ``` + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError( + f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}" + ) + + if self.axis_value not in [0, 1]: + raise ValueError( + f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}" + ) + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor, axis): + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda( + qtensor, meta=meta, device=self.device + ) # Move to device and cast to dtype + return qtensor, meta + + def _dequantize(self, qtensor): + quant_tensor, meta = qtensor + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + +class SinkCache(Cache): + """ + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to + generate beyond the length of its context window, without losing fluency in the conversation. As it discards past + tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Parameters: + window_length (`int`): + The length of the context window. + num_sink_tokens (`int`): + The number of sink tokens. See the original paper for more information. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SinkCache() + ``` + """ + + is_sliding = True + + def __init__(self, window_length: int, num_sink_tokens: int) -> None: + super().__init__() + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + @staticmethod + def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] + shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] + shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_rerotation_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length.""" + return self.window_length + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, + `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the + rotation as the tokens are shifted. + + Return: + A tuple containing the updated key and value states. + """ + # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models + # with partially rotated position embeddings, like Phi or Persimmon. + if cache_kwargs is None: + cache_kwargs = {} + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") + partial_rotation_size = cache_kwargs.get("partial_rotation_size") + using_rope = cos is not None and sin is not None + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the sin/cos cache, which holds sin/cos values for all possible positions + if using_rope and layer_idx == 0: + # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove + # after all RoPE models have a llama-like cache utilization. + if cos.dim() == 2: + self._cos_cache = cos + self._sin_cache = sin + elif self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] + elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : + ] + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, + self._cos_cache[: self.window_length], + self._sin_cache[: self.window_length], + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb( + keys_to_keep, rerotation_cos, rerotation_sin + ) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : + ] + self.value_cache[layer_idx] = torch.cat( + [sink_values, values_to_keep, value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. If you are manually setting the batch size, make sure to take into account the + number of beams if you are running beam search + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if max_cache_len is None else max_cache_len + ) + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self._dtype = dtype + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + self.max_cache_len, + self.head_dim, + ) + device = torch.device(device) if device is not None else None + for idx in range(config.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[idx] + else: + layer_device = device + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class SlidingWindowCache(StaticCache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SlidingWindowCache() + ``` + """ + + is_sliding = True + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + max_cache_len = min(config.sliding_window, max_cache_len) + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) + if cache_position.shape[0] > self.max_cache_len: + k_out = key_states[:, :, -self.max_cache_len :, :] + v_out = value_states[:, :, -self.max_cache_len :, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones( + self.max_cache_len, dtype=torch.long, device=value_states.device + ).cumsum(0) + cache_position = cache_position.clamp(0, self.max_cache_len - 1) + to_shift = cache_position >= self.max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + + return k_out, v_out + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class EncoderDecoderCache(Cache): + """ + Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and + cross-attention caches. + + Example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") + >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") + + >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") + + >>> # Prepare cache classes for encoder and decoder and pass it to model's forward + >>> self_attention_cache = DynamicCache() + >>> cross_attention_cache = DynamicCache() + >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + EncoderDecoderCache() + ``` + + """ + + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): + super().__init__() + self.self_attention_cache = self_attention_cache + self.cross_attention_cache = cross_attention_cache + self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) + + self.is_updated = {} + for layer_idx in range(len(cross_attention_cache.key_cache)): + self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) + + def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), + self.cross_attention_cache.to_legacy_cache(), + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls( + self_attention_cache=DynamicCache(), + cross_attention_cache=DynamicCache(), + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` + return self.self_attention_cache.get_seq_length(layer_idx) + + def reset(self): + if hasattr(self.self_attention_cache, "reset"): + self.self_attention_cache.reset() + if hasattr(self.cross_attention_cache, "reset"): + self.cross_attention_cache.reset() + elif not hasattr(self.self_attention_cache, "reset") and not hasattr( + self.cross_attention_cache, "reset" + ): + raise ValueError( + "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " + "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " + f"Got {self.self_attention_cache.__str__()} for the self attention cache and " + f"{self.cross_attention_cache.__str__()} for the cross attention cache." + ) + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): + if not ( + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) + ): + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) + + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) + + def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) + + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + @classmethod + def from_batch_splits(cls, splits: list["EncoderDecoderCache"]) -> "EncoderDecoderCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + self_attention_cache = DynamicCache() + cross_attention_cache = DynamicCache() + for idx in range(len(splits[0])): + layer_keys = torch.cat( + [current.self_attention_cache.key_cache[idx] for current in splits], dim=0 + ) + layer_values = torch.cat( + [current.self_attention_cache.value_cache[idx] for current in splits], dim=0 + ) + self_attention_cache.update(layer_keys, layer_values, idx) + + layer_keys = torch.cat( + [current.cross_attention_cache.key_cache[idx] for current in splits], dim=0 + ) + layer_values = torch.cat( + [current.cross_attention_cache.value_cache[idx] for current in splits], dim=0 + ) + cross_attention_cache.update(layer_keys, layer_values, idx) + return cls(self_attention_cache, cross_attention_cache) + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) + + +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention + and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention + and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (torch.dtype, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert + # ALL changes from the PR that commented the line below when reactivating it. + # is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self._dtype = dtype + self.num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) + + layer_switch = ( + config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 + ) # 2 is for BC + self.is_sliding = torch.tensor( + [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], + dtype=torch.bool, + ) + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + global_cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + max_cache_len, + self.head_dim, + ) + sliding_cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + min(config.sliding_window, max_cache_len), + self.head_dim, + ) + device = torch.device(device) if device is not None and isinstance(device, str) else None + for i in range(config.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[i] + else: + layer_device = device + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + if cache_position.shape[0] > max_cache_len: + k_out = key_states[:, :, -max_cache_len:, :] + v_out = value_states[:, :, -max_cache_len:, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, max_cache_len - 1) + to_shift = cache_position >= max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % max_cache_len + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + return k_out, v_out + + def _static_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + sliding_window = cache_kwargs.get("sliding_window") + + # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used + # when the cache is initialized in the forward pass (e.g. Gemma2) + if self.key_cache[layer_idx].device != key_states.device: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + if self.value_cache[layer_idx].device != value_states.device: + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if sliding_window: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def get_seq_length(self, layer_idx: int | None = 0): + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx != 0: + raise ValueError( + "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " + "Using the `layer_idx` argument is not supported." + ) + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class HybridChunkedCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention + and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention + and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert + # ALL changes from the PR that commented the line below when reactivating it. + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.bfloat16, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + if not hasattr(config, "sliding_window") or config.sliding_window is None: + self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) + else: + self.sliding_window = config.sliding_window + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self._dtype = dtype + + if hasattr(config.get_text_config(), "no_rope_layers"): + self.is_sliding = config.no_rope_layers + else: + layer_switch = getattr(config, "sliding_window_pattern", 2) + self.is_sliding = [ + bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers) + ] + + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] + + def initialise_cache_layer(self, layer_idx, key_states): + if len(self.key_cache) > layer_idx: + return + + num_key_value_heads = key_states.shape[1] + device = key_states.device + global_cache_shape = ( + self.max_batch_size, + num_key_value_heads, + self.max_cache_len, + self.head_dim, + ) + sliding_cache_shape = ( + self.max_batch_size, + num_key_value_heads, + self.sliding_window, + self.head_dim, + ) + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + cumulative_length = self.cumulative_length[layer_idx] + # Update it now that we saved the value above + self.cumulative_length[layer_idx] += key_states.shape[-2] + is_full = cumulative_length >= max_cache_len + if is_full: + full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) + # Fast decoding path -> here as the effective size is still sliding window, it is extremely important + # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed adress + # in memory (the values are the same as the full states, but not the address!!) + if key_states.shape[-2] == 1: + self.key_cache[layer_idx].copy_(full_key_states) + self.value_cache[layer_idx].copy_(full_value_states) + return self.key_cache[layer_idx], self.value_cache[layer_idx] + elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: + # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states + else: + full_key_states = torch.cat( + (k_out[:, :, :cumulative_length, :], key_states), dim=-2 + ) + full_value_states = torch.cat( + (v_out[:, :, :cumulative_length, :], value_states), dim=-2 + ) + else: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) + self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return full_key_states, full_value_states + + def _static_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + self.initialise_cache_layer(layer_idx, key_states) + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if self.is_sliding[layer_idx]: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def get_seq_length(self, layer_idx: int | None = 0): + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx != 0: + raise ValueError( + "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " + "Using the `layer_idx` argument is not supported." + ) + if len(self.key_cache) == 0: + return 0 + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: torch.device | str | None = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to( + device=conv_state.device, dtype=conv_state.dtype + ) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + +class OffloadedStaticCache(StaticCache): + """ + Static cache class to be used with `torch.compile(model)` that offloads to the CPU or + another device. + + Args: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize + the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`Union[str, torch.device]`): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*): + The default `dtype` to use when initializing the cache. + offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): + The device to offload to. Defaults to CPU. + layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is splitted between differents gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None, + device: str | torch.device, + dtype: torch.dtype | None = None, + offload_device: str | torch.device = torch.device("cpu"), + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super(Cache, self).__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if max_cache_len is None else max_cache_len + ) + self.device = ( + torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) + ) + self.offload_device = torch.device(offload_device) + self._dtype = dtype if dtype is not None else torch.float32 + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + + cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) + + # Create offloaded CPU tensors. + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + for i in range(config.num_hidden_layers): + # First layer is always on-device. + device = self.device if i == 0 else self.offload_device + + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) + + self.key_cache.append(key_cache) + self.value_cache.append(value_cache) + + # Create device tensors. + self._device_key_cache: list[torch.Tensor] = [] + self._device_value_cache: list[torch.Tensor] = [] + + for i in range(2): + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) + + self._device_key_cache.append(key_cache) + self._device_value_cache.append(value_cache) + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Create new CUDA stream for parallel prefetching. + self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the + `cache_position` input to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + if layer_idx == 0: + # Update seen tokens. + # TODO(gante): Remove this. + self._seen_tokens += key_states.shape[-2] + + # Always there. + k_out = self.key_cache[0] + v_out = self.value_cache[0] + else: + # Wait for prefetch stream. + if self._prefetch_stream is not None: + torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) + + k_out = self._device_key_cache[layer_idx & 1] + v_out = self._device_value_cache[layer_idx & 1] + + self._prefetch_layer(layer_idx + 1) + + cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + + # Copy the values to the offloaded device as well. + if layer_idx == 0: + self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) + self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does + # explicitly an in-place operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # Copy the values to the offloaded device as well. + if layer_idx != 0: + cache_position = cache_position.to(self.offload_device) + key_states = key_states.to(self.offload_device) + value_states = value_states.to(self.offload_device) + + try: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + self.key_cache[layer_idx][:, :, cache_position] = key_states + self.value_cache[layer_idx][:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + + # TODO(gante): Remove this. + return self._seen_tokens + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length of the cached states.""" + + return self.max_cache_len + + def reset(self) -> None: + """Resets the cache values while preserving the objects.""" + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Zero out cache. + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address. + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + @property + def seen_tokens(self) -> int: + # For backwards compatibility. + # TODO(gante): Remove this. + return self._seen_tokens + + def _create_key_value_cache_tensors( + self, shape: tuple[int, ...], device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor]: + """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static + addresses for non-CPU tensors. + + Args: + shape (`Tuple[int, ...]`): Shape. + device (`torch.device`): Device. + + Returns: + Key and value cache tensors as a tuple. + """ + + is_cpu_device = device == torch.device("cpu") + + key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) + value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) + + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(key_cache) + torch._dynamo.mark_static_address(value_cache) + + return key_cache, value_cache + + def _prefetch_layer(self, layer_idx: int) -> None: + """Prefetch a layer to the device. Needs to be called in order of layer indices.""" + + # Don't fetch layers that do not exist. + if layer_idx >= len(self.key_cache): + return + + # Alternate between two on-device caches. + if self._prefetch_stream is not None: + with torch.cuda.stream(self._prefetch_stream): + self._prefetch_layer_in_context(layer_idx) + else: + self._prefetch_layer_in_context(layer_idx) + + def _prefetch_layer_in_context(self, layer_idx: int) -> None: + """Performs the actual copy of the layer to device cache.""" + + self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) + self._device_value_cache[layer_idx & 1].copy_( + self.value_cache[layer_idx], non_blocking=True + ) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py new file mode 100644 index 0000000000..7dc65a0923 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py @@ -0,0 +1,447 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Llama4VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a + Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llama4 109B. + + e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + num_hidden_layers (`int`, *optional*, defaults to 34): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + vision_output_dim (`int`, *optional*, defaults to 7680): + Dimensionality of the vision model output. Includes output of transformer + encoder with intermediate layers and global transformer encoder. + image_size (`int`, *optional*, defaults to 448): + The size (resolution) of each image *tile*. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + vision_feature_layer (``, *optional*, defaults to -1): TODO + vision_feature_select_strategy (`int`, *optional*, defaults to `"default"`): TODO + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + pixel_shuffle_ratio (`int`, *optional*, defaults to 0.5): TODO + projector_input_dim (`int`, *optional*, defaults to 4096): TODO + projector_output_dim (`int`, *optional*, defaults to 4096): TODO + multi_modal_projector_bias (`int`, *optional*, defaults to `False`): TODO + projector_dropout (`int`, *optional*, defaults to 0.0): TODO + attention_dropout (`int`, *optional*, defaults to 0.0): TODO + rope_theta (`int`, *optional*, defaults to 10000): TODO + """ + + base_model_tp_plan = { + "model.layers.*.self_attn.q_proj": "colwise", + "model.layers.*.self_attn.k_proj": "colwise", + "model.layers.*.self_attn.v_proj": "colwise", + "model.layers.*.self_attn.o_proj": "rowwise", + "vision_adapter.mlp.fc1": "colwise", + "vision_adapter.mlp.fc2": "rowwise", + "patch_embedding.linear": "colwise_rep", + } + model_type = "llama4_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size: int = 768, + hidden_act: str = "gelu", + num_hidden_layers: int = 34, + num_attention_heads: int = 16, + num_channels: int = 3, + intermediate_size: int = 5632, + vision_output_dim: int = 7680, + image_size: int = 448, + patch_size: int = 14, + norm_eps: float = 1e-5, + vision_feature_layer=-1, + vision_feature_select_strategy="default", + initializer_range: float = 0.02, + pixel_shuffle_ratio=0.5, + projector_input_dim=4096, + projector_output_dim=4096, + multi_modal_projector_bias=False, + projector_dropout=0.0, + attention_dropout=0.0, + rope_theta=10000, + **kwargs, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.num_channels = num_channels + self.intermediate_size = intermediate_size + self.image_size = image_size + self.vision_output_dim = vision_output_dim + self.patch_size = patch_size + self.norm_eps = norm_eps + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.pixel_shuffle_ratio = pixel_shuffle_ratio + self.projector_input_dim = projector_input_dim + self.projector_output_dim = projector_output_dim + self.multi_modal_projector_bias = multi_modal_projector_bias + self.projector_dropout = projector_dropout + self.attention_dropout = attention_dropout + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy + self.rope_theta = rope_theta + super().__init__(**kwargs) + + +class Llama4TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a + Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llama4 109B. + + e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 202048): + Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented + by the `inputs_ids` passed when calling [`Llama4TextModel`]. + hidden_size (`int`, *optional*, defaults to 5120): + Dimensionality of the embeddings and hidden states. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 40): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If not + specified, will default to `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 128): TODO + hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions. + pad_token_id (`int`, *optional*, defaults to 128004): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the beginning of sentence token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end of sentence token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to `500000.0`): + The base period of the RoPE embeddings. + attention_dropout (`int`, *optional*, defaults to 0.0): TODO + num_experts_per_tok (`int`, *optional*, defaults to 1): TODO + num_local_experts (`int`, *optional*, defaults to 16): TODO + moe_layers (`int`, *optional*): TODO + interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO + use_qk_norm (`int`, *optional*, defaults to `True`): TODO + output_router_logits (`int`, *optional*, defaults to `False`): TODO + router_aux_loss_coef (`int`, *optional*, defaults to 0.001): TODO + router_jitter_noise (`int`, *optional*, defaults to 0.0): TODO + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + + no_rope_layers (`int`, *optional*): TODO + no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO + attention_chunk_size (`int`, *optional*, defaults to 8192): + + attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO + floor_scale (`int`, *optional*, defaults to 8192): TODO + attn_scale (`int`, *optional*, defaults to 0.1): TODO + cache_implementation (``, *optional*, defaults to `"hybrid"`): + + Example: + """ + + model_type = "llama4_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.input_layernorm.weight": "sequence_parallel", + "layers.*.post_attention_layernorm.weight": "sequence_parallel", + "norm.weight": "sequence_parallel", + "layers.*.feed_forward.shared_expert.gate_proj": "local_colwise", + "layers.*.feed_forward.shared_expert.up_proj": "local_colwise", + "layers.*.feed_forward.shared_expert.down_proj": "local_rowwise", + "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise", # row because not linear + "layers.*.feed_forward.experts.down_proj": "local_colwise", # col because not linear + "layers.*.feed_forward.experts": "local", + "layers.*.feed_forward.gate_proj": "local_colwise", + "layers.*.feed_forward.up_proj": "local_colwise", + "layers.*.feed_forward.down_proj": "local_rowwise", + "layers.*.feed_forward": "gather", + } + + def __init__( + self, + vocab_size=202048, + hidden_size=5120, + intermediate_size=8192, + intermediate_size_mlp=16384, + num_hidden_layers=48, + num_attention_heads=40, + num_key_value_heads=8, + head_dim=128, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=500000, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=16, + moe_layers=None, + interleave_moe_layer_step=1, + use_qk_norm=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + rope_scaling=None, + no_rope_layers=None, + no_rope_layer_interval=4, + attention_chunk_size=8192, + attn_temperature_tuning=4, + floor_scale=8192, + attn_scale=0.1, + cache_implementation="hybrid", + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.attn_temperature_tuning = attn_temperature_tuning + self.attn_scale = attn_scale + self.floor_scale = floor_scale + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.intermediate_size_mlp = intermediate_size_mlp + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.rope_scaling = rope_scaling + self.attention_bias = False + self.cache_implementation = cache_implementation + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.use_qk_norm = use_qk_norm + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + default_no_rope_layers = [ + int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers) + ] + + # no_rope_layers == [] is invalid as we cannot have 0 layers + self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers + + self.interleave_moe_layer_step = interleave_moe_layer_step + self.moe_layers = ( + moe_layers + if moe_layers is not None + else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step)) + ) + self.attention_chunk_size = attention_chunk_size + + +class Llama4Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Llama4Model`]. It is used to instantiate an + Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llama4 109B. + + e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vision_config (`Llama4VisionConfig`, *optional*): + The Llama4 Vision config. + text_config (`Llama4TextConfig`, *optional*): + The Llama4 Text config. + boi_token_index (`int`, *optional*, defaults to 200080): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 200081): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 200092): + The image token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + + ```python + >>> from transformers import Llama4Model, Llama4Config + + >>> # Initializing a Llama4 7B style configuration + >>> configuration = Llama4Config() + + >>> # Initializing a model from the Llama4 7B style configuration + >>> model = Llama4Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama4" + sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig} + base_model_tp_plan = { + "multi_modal_projector.linear_1": "colwise_rep", + } + + def __init__( + self, + vision_config=None, + text_config=None, + boi_token_index=200080, + eoi_token_index=200081, + image_token_index=200092, + tie_word_embeddings=False, + **kwargs, + ): + if vision_config is None: + self.vision_config = Llama4VisionConfig() + logger.info("vision_config is None, using default llama4 vision config") + elif isinstance(vision_config, dict): + self.vision_config = Llama4VisionConfig(**vision_config) + elif isinstance(vision_config, Llama4VisionConfig): + self.vision_config = vision_config + + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + if text_config is None: + self.text_config = Llama4TextConfig() + logger.info("text_config is None, using default llama4 text config") + elif isinstance(text_config, dict): + self.text_config = Llama4TextConfig(**text_config) + elif isinstance(text_config, Llama4TextConfig): + self.text_config = text_config + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"] diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py new file mode 100644 index 0000000000..b17883628f --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.processing_utils import Unpack +from transformers.utils import ( + is_torch_flex_attn_available, + logging, +) +from .transformers_4_51_3__configuration_llama4 import Llama4TextConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from transformers.integrations.flex_attention import make_flex_block_causal_mask + +logger = logging.get_logger(__name__) + + +class Llama4TextL2Norm(torch.nn.Module): + def __init__(self, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) + + def extra_repr(self): + return f"eps={self.eps}" + + +class Llama4TextRotaryEmbedding(nn.Module): + def __init__(self, config: Llama4TextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + self.rope_type = "llama3" if config.rope_scaling is not None else "default" + + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + freqs_cis = freqs_cis * self.attention_scaling + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # print(f"{module.layer_idx=} {module.num_key_value_groups=}") + # print(f"{module.layer_idx=} {module.head_dim=}") + # print(f"{module.layer_idx=} {module.training=}") + # print(f"{scaling=}") + # print(f"{dropout=}") + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Llama4TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Llama4TextConfig, layer_idx, use_rope: bool): # we added use_rope to not be dependent on the layer index + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attn_scale = config.attn_scale + self.floor_scale = config.floor_scale + self.attn_temperature_tuning = config.attn_temperature_tuning + self.attention_dropout = config.attention_dropout + self.is_causal = True + # self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers + self.use_rope = use_rope + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + if self.config.use_qk_norm and self.use_rope: + self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.use_rope: # the 16E model skips rope for long context on certain layers + query_states, key_states = apply_rotary_emb( + query_states, key_states, position_embeddings.to(query_states.device) + ) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + device = query_states.device + attn_scales = ( + torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 + ).to(device) + attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 + query_states = (query_states * attn_scales).to(query_states.dtype) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # print(f"{self.layer_idx=} {cache_position=} {attention_mask=}") + # print(f"{self.layer_idx=} {query_states.flatten()[:10]=}") + # print(f"{self.layer_idx=} {key_states.flatten()[:10]=}") + # print(f"{self.layer_idx=} {value_states.flatten()[:10]=}") + # print(f"{self.layer_idx=} {kwargs=}") + # print(f"{self.layer_idx=} {attention_interface=}") + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py new file mode 100644 index 0000000000..9acc27eb9f --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +from copy import deepcopy +from typing import Any + +import torch +from transformers.cache_utils import ( + Cache, # used to let GenerationMixin know that we use a Cache object +) + +from .configuration_decilm import DeciLMConfig +from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2 +from .transformers_4_44_2__cache_utils import SinkCache, SlidingWindowCache, StaticCache +from .transformers_4_51_3__cache_utils import HybridChunkedCache + +LayerIndex = tuple[ + int, ... +] # supports both regular transformer blocks and parallel transformer multi-blocks + + +class VariableCache(Cache_4_44_2, Cache): + """ + A Cache object that supports a different Cache implementation for every layer, + including layers without any kv-cache. + Implemented using a list of Cache objects, each represents a "model" with 1 layer. + The default implementation for the layer caches is StaticCache. + The cache of each layer is allocated to the same gpu as the layer itself. + """ + + def __init__( + self, + *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions + config: DeciLMConfig, + batch_size: int | None = None, + max_cache_len: int | None = None, + dtype: torch.dtype = torch.get_default_dtype(), + max_batch_size: int | None = None, + **kwargs, + ) -> None: + Cache_4_44_2.__init__(self) + + self.config = deepcopy(config) + self.max_batch_size = batch_size or max_batch_size + self.batch_size = self.max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if (max_cache_len is None) else max_cache_len + ) + self.dtype = dtype + + self.layer_caches: dict[LayerIndex, Cache_4_44_2] = {} + self.layer_devices: dict[LayerIndex, torch.device] = {} + + def __repr__(self): + return ( + f"VariableCache:\n" + f"==============\n" + f"max_batch_size={self.max_batch_size}\n" + f"batch_size={self.batch_size}\n" + f"max_cache_len={self.max_cache_len}\n" + f"dtype={self.dtype}\n" + f"layer_caches={self.layer_caches}\n" + f"layer_devices={self.layer_devices}\n" + f"==============\n" + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int | LayerIndex, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(layer_idx, int): + layer_idx = _int_to_layer_index(layer_idx) + + if layer_idx not in self.layer_caches: + self.layer_devices[layer_idx] = key_states.device + self._init_layer_cache(layer_idx) + + layer_cache = self.layer_caches[layer_idx] + assert layer_cache is not None, ( + f"Trying to update the cache of a cache-less layer: {layer_idx=}" + ) + + k_out, v_out = layer_cache.update( + key_states=key_states, value_states=value_states, layer_idx=0, cache_kwargs=cache_kwargs + ) + + input_seq_len = key_states.shape[2] # [batch_size, num_kv_heads, seq_len, hidden_size] + cache_seq_len = self.get_seq_length(layer_idx) + seq_len = max(input_seq_len, cache_seq_len) + + k_out = k_out[:, :, :seq_len, :] + v_out = v_out[:, :, :seq_len, :] + return k_out, v_out + + def _init_layer_cache(self, layer_idx: LayerIndex) -> None: + block_config = self.config.get_block_config(layer_idx) + attention_config = block_config.attention + + if attention_config.no_op or attention_config.replace_with_linear: + return None + + device = self.layer_devices[layer_idx] + assert device is not None, f"Trying to init layer cache for {layer_idx=} without device" + + config = deepcopy(self.config) + config.num_hidden_layers = 1 + config.num_key_value_heads = ( + self.config.num_attention_heads // attention_config.n_heads_in_group + ) + + if attention_config.is_llama4: + attention_chunk_size = attention_config.llama4.attention_chunk_size + is_chunked = attention_chunk_size is not None + config.no_rope_layers = [int(is_chunked)] + config.attention_chunk_size = ( + attention_chunk_size if is_chunked else config.get_min_attention_chunk_size() + ) + self.layer_caches[layer_idx] = HybridChunkedCache( + config=config, + max_batch_size=self.max_batch_size, + max_cache_len=self.max_cache_len, + dtype=self.dtype, + ) + return + + if attention_config.window_length is not None: + if not attention_config.is_sink: + config.sliding_window = attention_config.window_length + self.layer_caches[layer_idx] = SlidingWindowCache( + config=config, + max_batch_size=self.max_batch_size, + max_cache_len=self.max_cache_len, + device=device, + dtype=self.dtype, + ) + return + elif not attention_config.unshifted_sink: + self.layer_caches[layer_idx] = SinkCache( + window_length=attention_config.window_length, + num_sink_tokens=attention_config.num_sink_tokens, + ) + return + + self.layer_caches[layer_idx] = StaticCache( + config=config, + max_batch_size=self.max_batch_size, + max_cache_len=self.max_cache_len, + device=device, + dtype=self.dtype, + ) + + def _get_arbitrary_cache(self) -> Cache_4_44_2: + if len(self.layer_caches) == 0: + raise NoCacheFoundError() + layer_cache = next(iter(self.layer_caches.values())) + return layer_cache + + def get_seq_length(self, layer_idx: int | LayerIndex | None = 0) -> int: + """default 0 to match standard HF implementation""" + if (layer_idx is None) or ( + layer_idx == 0 and _int_to_layer_index(0) not in self.layer_caches + ): + try: + layer_cache = self._get_arbitrary_cache() + return layer_cache.get_seq_length() + except NoCacheFoundError: + return 0 + + if isinstance(layer_idx, int): + layer_idx = _int_to_layer_index(layer_idx) + + layer_cache = self.layer_caches[layer_idx] + return layer_cache.get_seq_length() + + def get_max_length(self) -> int | None: + """Returns the maximum sequence length of the cached states.""" + return self.max_cache_len + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + for layer_idx, layer_cache in self.layer_caches.items(): + if hasattr(layer_cache, "reset"): + layer_cache.reset() + else: + self.layer_caches[layer_idx] = None + self.layer_devices[layer_idx] = None + # self._init_layer_cache(layer_idx) + + +class NoCacheFoundError(Exception): + pass + + +def _int_to_layer_index(layer_idx: int) -> LayerIndex: + return (layer_idx,) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py new file mode 100644 index 0000000000..4c8f86cdbc --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + + +def _apply_rotary_emb_torch( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: int | float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +def _yarn_find_correction_range( + low_rot: int, high_rot: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 +) -> tuple[int, int]: + low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, self.beta_slow, self.rotary_dim, self.base, self.max_position_embeddings + ) + # print(f"low: {low}, high: {high}") + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + return cache From 50a580c2276238167822d0060327dcad01ca7159 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 12 Nov 2025 23:56:18 +0100 Subject: [PATCH 08/62] Compress tutorial (PoC) (#492) ## What does this PR do? Compress tutorial (PoC) + compress cli app. --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Signed-off-by: Liana Mikaelyan Signed-off-by: Liana Mikaelyan <45925959+LianaMikael@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Liana Mikaelyan Co-authored-by: Liana Mikaelyan <45925959+LianaMikael@users.noreply.github.com> --- examples/compress/README.md | 200 ++++++++++++++++++ .../Llama-3_1-8B.yaml | 110 ++++++++++ .../llama-3_1-8B_pruneffn_memory.yaml | 21 ++ .../pruning/attn_pruning.yaml | 16 ++ .../pruning/ffn_pruning.yaml | 12 ++ .../pruning/hidden_dim_pruning.yaml | 15 ++ .../pruning/pruning_defaults.yaml | 32 +++ .../validate_model_defaults.yaml | 15 ++ .../validate_solutions_defaults.yaml | 10 + examples/compress/main.py | 164 ++++++++++++++ examples/pruning/README.md | 2 + modelopt/torch/_compress/__init__.py | 15 ++ modelopt/torch/_compress/dataset/__init__.py | 15 ++ .../_compress/dataset/prepare_dataset.py | 64 ++++++ .../nas/plugins/compress_nas_plugin.py | 23 +- modelopt/torch/_compress/tools/logger.py | 166 +++++++++++++++ setup.py | 2 + .../torch/_compress/test_compress.py | 7 +- 18 files changed, 879 insertions(+), 10 deletions(-) create mode 100644 examples/compress/README.md create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml create mode 100644 examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml create mode 100644 examples/compress/main.py create mode 100644 modelopt/torch/_compress/__init__.py create mode 100644 modelopt/torch/_compress/dataset/__init__.py create mode 100644 modelopt/torch/_compress/dataset/prepare_dataset.py create mode 100644 modelopt/torch/_compress/tools/logger.py diff --git a/examples/compress/README.md b/examples/compress/README.md new file mode 100644 index 0000000000..a4881150d0 --- /dev/null +++ b/examples/compress/README.md @@ -0,0 +1,200 @@ +# Compress Algorithm Tutorial + +This tutorial demonstrates how to compress large language models using the Compress algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). +This tutorial demonstrates how to compress large language models using the compress algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). +The goal of the algorithm it to find the most optimal modifications to MLP and attention layers of the model, resulting in a heterogeneous model architecture. +The supported modifications are: + +- `ffn_intermediate_size`: different FFN intermediate sizes +- `attention op/noop`: complete removal of attention layers + +To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on Mixed-Integer Programming (MIP) algorithm to find the most optimal combination of layer modifications that satisfy the target requirements. + +In this example, we compress the [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. + +## Environment + +- Install TensorRT-Model-Optimizer in editable mode with the corresponding dependencies: + +```bash +pip install -e .[hf,compress] +``` + +- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU. + +## Compress the Model + +1. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file. + + **_NOTE:_** + How to choose `intermediate_size_list`? + The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g. 15%, 20%, 30% etc of the original). Note that the values must be hardware-friendly (divisible by a 256) to avoid issues with tensor operations in subsequent steps. + + Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` GiB. This means that the algorithm will choose the candidates with highest accuracy that also meet the specified requirements. + +2. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). + + dataset split: "code", "math", "stem", "chat", excluding reasoning samples (2.62GB) + + ```bash + python -m modelopt.torch._compress.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2 + ``` + +3. Run the compression script. + + ```bash + torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Compress Progress" + ``` + + This will save the full output to `log.txt` and display the following progress on screen: + + ```bash + [2025-11-02 12:06:34][rank-0][main.py:71] Compress Progress 1/8: starting compression pipeline + [2025-11-02 12:06:45][rank-0][compress_nas_plugin.py:123] Compress Progress 2/8: converting model from HF to DeciLM (single-gpu) + [2025-11-02 12:07:07][rank-0][compress_nas_plugin.py:132] Compress Progress 3/8: scoring pruning activations (multi-gpu) + [2025-11-02 12:11:36][rank-0][compress_nas_plugin.py:137] Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu) + [2025-11-02 12:12:20][rank-0][compress_nas_plugin.py:217] Compress Progress 5/8: building replacement library and subblock statistics (single-gpu) + [2025-11-02 12:12:21][rank-0][compress_nas_plugin.py:222] Compress Progress 6/8: calculating one block scores (multi-gpu) + [2025-11-02 12:50:41][rank-0][compress_nas_plugin.py:226] Compress Progress 7/8: running MIP and realizing models (multi-gpu) + [2025-11-02 12:52:34][rank-0][main.py:115] Compress Progress 8/8: compression pipeline completed (multi-gpu) + ``` + + Once the process is complete, the resulting network architecture will be recorded in `log.txt` for your review: + + ```bash + ... + block_0: attention gqa_4 ffn intermediate_14336 + block_1: attention gqa_4 ffn intermediate_14336 + block_2: attention gqa_4 ffn intermediate_14336 + block_3: attention gqa_4 ffn intermediate_14336 + block_4: attention gqa_4 ffn intermediate_14336 + block_5: attention gqa_4 ffn intermediate_14336 + block_6: attention gqa_4 ffn intermediate_14336 + block_7: attention gqa_4 ffn intermediate_14336 + block_8: attention gqa_4 ffn intermediate_14336 + block_9: attention gqa_4 ffn intermediate_14336 + block_10: attention gqa_4 ffn intermediate_14336 + block_11: attention gqa_4 ffn intermediate_14336 + block_12: attention gqa_4 ffn intermediate_14336 + block_13: attention gqa_4 ffn intermediate_14336 + block_14: attention gqa_4 ffn intermediate_14336 + block_15: attention gqa_4 ffn intermediate_14336 + block_16: attention gqa_4 ffn intermediate_14336 + block_17: attention no_op ffn intermediate_14336 + block_18: attention no_op ffn intermediate_14336 + block_19: attention no_op ffn intermediate_14336 + block_20: attention no_op ffn intermediate_14336 + block_21: attention no_op ffn intermediate_14336 + block_22: attention no_op ffn intermediate_14336 + block_23: attention no_op ffn intermediate_14336 + block_24: attention no_op ffn intermediate_14336 + block_25: attention no_op ffn intermediate_14336 + block_26: attention no_op ffn intermediate_14336 + block_27: attention no_op ffn intermediate_14336 + block_28: attention no_op ffn intermediate_14336 + block_29: attention gqa_4 ffn intermediate_14336 + block_30: attention gqa_4 ffn intermediate_14336 + block_31: attention gqa_4 ffn intermediate_14336 + + [2025-11-02 04:53:11,332]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 75796.4140625, 'stats.ffn_num_params': 5637275648, 'stats.num_kv_heads': 160, 'stats.kv_cache_memory_mib': 61440.0, 'stats.ffn_memory_mib': 10752.25, 'stats.attention_memory_mib': 63040.15625, 'stats.attention_num_params': 838942720, 'stats.num_params': 7526895616, 'stats.has_attention': 20, 'stats.has_ffn': 32} + ... + ################################################################ + validate_model_and_extract_token_probs(model_name='teacher') + ################################################################ + ... + Average losses = {'lm_loss': 1.118250765837729, 'token_accuracy_top_1': 0.7331905364990234, 'token_accuracy_top_5': 0.9094219207763672, 'token_accuracy_top_10': 0.9423646926879883} + ... + ################################################################ + validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True) + ################################################################ + .... + Average losses = {'lm_loss': 1.7577573340386152, 'token_accuracy_top_1': 0.6225490570068359, 'token_accuracy_top_5': 0.846257209777832, 'token_accuracy_top_10': 0.8987817764282227} + ``` + + 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). Let's rerun MIP search aiming for 15% memory reduction. + +## Re-run MIP Search with different constraints + +If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag. +This assumes pruning, replacement library building, NAS scoring, and subblock stats calculation have already been completed. + +For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`. + +```bash +torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Compress Progress" +``` + +This will generate the following network architecture (see `log.txt`): + +```bash +block_0: attention gqa_4 ffn intermediate_14336 +block_1: attention gqa_4 ffn intermediate_14336 +block_2: attention gqa_4 ffn intermediate_14336 +block_3: attention gqa_4 ffn intermediate_14336 +block_4: attention gqa_4 ffn intermediate_14336 +block_5: attention gqa_4 ffn intermediate_14336 +block_6: attention gqa_4 ffn intermediate_14336 +block_7: attention gqa_4 ffn intermediate_14336 +block_8: attention gqa_4 ffn intermediate_14336 +block_9: attention gqa_4 ffn intermediate_14336 +block_10: attention gqa_4 ffn intermediate_14336 +block_11: attention gqa_4 ffn intermediate_14336 +block_12: attention gqa_4 ffn intermediate_14336 +block_13: attention gqa_4 ffn intermediate_14336 +block_14: attention gqa_4 ffn intermediate_14336 +block_15: attention gqa_4 ffn intermediate_14336 +block_16: attention gqa_4 ffn intermediate_14336 +block_17: attention gqa_4 ffn intermediate_14336 +block_18: attention no_op ffn intermediate_14336 +block_19: attention no_op ffn intermediate_14336 +block_20: attention no_op ffn intermediate_14336 +block_21: attention gqa_4 ffn intermediate_14336 +block_22: attention no_op ffn intermediate_14336 +block_23: attention no_op ffn intermediate_14336 +block_24: attention no_op ffn intermediate_14336 +block_25: attention gqa_4 ffn intermediate_14336 +block_26: attention gqa_4 ffn intermediate_14336 +block_27: attention gqa_4 ffn intermediate_14336 +block_28: attention gqa_4 ffn intermediate_14336 +block_29: attention gqa_4 ffn intermediate_14336 +block_30: attention gqa_4 ffn intermediate_14336 +block_31: attention gqa_4 ffn intermediate_14336 + +[2025-11-02 12:50:42,024]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 94708.4609375, 'stats.has_ffn': 32, 'stats.ffn_memory_mib': 10752.25, 'stats.kv_cache_memory_mib': 79872.0, 'stats.attention_num_params': 1090625536, 'stats.ffn_num_params': 5637275648, 'stats.has_attention': 26, 'stats.num_params': 7778578432, 'stats.attention_memory_mib': 81952.203125, 'stats.num_kv_heads': 208} +... +################################################################ +validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True) +################################################################ +Average losses = {'lm_loss': 1.2425934937782586, 'token_accuracy_top_1': 0.703862190246582, 'token_accuracy_top_5': 0.8954982757568359, 'token_accuracy_top_10': 0.9336576461791992 +``` + +On the other hand, if you set `target_memory: 28_000`, you'll observe that the intermediate FFN sizes are significantly reduced in certain layers (see `log.txt` for details): + +```bash +block_5: attention no_op ffn intermediate_11520 +block_6: attention no_op ffn intermediate_14336 +block_7: attention no_op ffn intermediate_8704 +block_8: attention no_op ffn intermediate_14336 +block_9: attention no_op ffn intermediate_3072 +block_10: attention no_op ffn intermediate_11520 +block_11: attention no_op ffn intermediate_11520 +block_12: attention no_op ffn intermediate_11520 +block_13: attention no_op ffn intermediate_11520 +block_14: attention no_op ffn intermediate_3072 +``` + +## Evaluation + +Once the model is ready, you can evaluate it using [Language Model Evaluation Harness](https://pypi.org/project/lm-eval/). For example, run the following to evaluate the model on [Massive Multitask Language Understanding](https://huggingface.co/datasets/cais/mmlu) benchmark. + +```bash +lm_eval --model hf \ + --model_args pretrained=path/to/model,dtype=bfloat16,trust_remote_code=true,parallelize=True \ + --tasks mmlu \ + --num_fewshot 5 \ + --batch_size 4 +``` + +## Advanced usage + +Modify `path/to/Llama-3_1-8B yaml` file for advanced compression scenarios. diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml new file mode 100644 index 0000000000..70b5304c5b --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 10 # default is 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml new file mode 100644 index 0000000000..cfd7f93e81 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml @@ -0,0 +1,21 @@ +defaults: + - Llama-3_1-8B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 96_000 # 96 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..96a8ca72e4 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,12 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..5d5307b9c7 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,32 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_outpt_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..572331a84f --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/compress/main.py b/examples/compress/main.py new file mode 100644 index 0000000000..c8b287fccd --- /dev/null +++ b/examples/compress/main.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Main script for running the compress algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146). + +This script provides two modes: +1. Default mode: Runs the full compression pipeline +2. MIP-only mode: Runs only the MIP search and realize models phase + +Usage: + # Full compression pipeline + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml + + # Only MIP search and realize models phase + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only +""" + +import argparse +import datetime +from pathlib import Path + +import mip_and_realize_models +import torch +from puzzle_tools.hydra_utils import register_hydra_resolvers + +import modelopt.torch.nas as mtn +from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel +from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.logger import mprint +from tests.utils.test_utils import initialize_hydra_config_for_dir + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Compress large language models using the Compress algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)" + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the main config YAML file (e.g., ./configs/llama_3.2_1B_pruneffn_memory.yaml)", + ) + parser.add_argument( + "--mip-only", + action="store_true", + help="Run only the MIP search and realize models phase (skip pruning and NAS scoring)", + ) + + return parser.parse_args() + + +def run_full_compress(hydra_config_path: str): + """Run the full compression pipeline. + + Args: + config_path: Path to the YAML configuration file + """ + mprint("Compress Progress 1/8: starting compression pipeline") + with NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)): + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # Convert model (convert from HF to DeciLM, score pruning activations, + # prune the model and save pruned checkpoints) + input_model = CompressModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(hydra_cfg.puzzle_dir), + "input_model_path": hydra_cfg.input_hf_model_path, + "hydra_config_dir": hydra_config_dir, + "hydra_config_name": hydra_config_name, + "dataset_path": str(hydra_cfg.dataset_path), + }, + ) + ], + ) + + # Run NAS search (build replacement library and compute stats, + # compute one block scores, run MIP and realize models) + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") + + +def run_mip_only(hydra_config_path: str): + """Run only the MIP search and realize models phase. + + This assumes that pruning, replacement library building, NAS scoring, and subblock stats calculation + have already been completed. + + Args: + hydra_config_path: Path to the YAML configuration file + """ + + with NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) as runtime: + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # mip_and_realize_models (distributed processing) + # TODO: How to make it part of mnt.search() api, similarly to run_full_compress() API + mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + + mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") + + +def main(): + args = parse_args() + + if args.mip_only: + run_mip_only(hydra_config_path=args.config) + else: + run_full_compress(hydra_config_path=args.config) + + +if __name__ == "__main__": + main() diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 3efa9eb79b..54f7322b15 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -23,6 +23,8 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar +For more advanced pruning strategies, such as the [Puzzle methodology](https://arxiv.org/pdf/2411.19146), please see [Puzzle pruning example](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/feature/compress/examples/compress). + ## Pre-Requisites For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.07`) which has all the dependencies installed. diff --git a/modelopt/torch/_compress/__init__.py b/modelopt/torch/_compress/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/modelopt/torch/_compress/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/dataset/__init__.py b/modelopt/torch/_compress/dataset/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/modelopt/torch/_compress/dataset/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/dataset/prepare_dataset.py b/modelopt/torch/_compress/dataset/prepare_dataset.py new file mode 100644 index 0000000000..49d63d1227 --- /dev/null +++ b/modelopt/torch/_compress/dataset/prepare_dataset.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import datasets +import fire +import numpy as np +from logger import mprint + + +def process_and_save_dataset( + dataset_name: str, + output_dir: str, + split: tuple = ("code", "math", "stem", "chat"), + overwrite: bool = False, +): + # Check if output_dir contains an existing dataset + dataset_dict_path = os.path.join(output_dir, "dataset_dict.json") + if os.path.exists(output_dir) and os.path.exists(dataset_dict_path): + if not overwrite: + mprint( + f"Output directory '{output_dir}' already contains a dataset. " + "Use '--overwrite True' to overwrite existing data." + ) + return + + ds = datasets.load_dataset(dataset_name, split=split) + ds = datasets.concatenate_datasets(ds) + # Filter out samples with reasoning = on + ds = ds.filter(lambda x: x["reasoning"] == "off") + # Hardcoded for dynamically create a deterministic train-val split + seed = 408 + generator = np.random.RandomState(seed=seed) + ds_split = ds.train_test_split(test_size=0.05, shuffle=True, generator=generator) + # Rename dataset names to follow previous conventions + ds_dict = datasets.DatasetDict( + { + "train": ds_split["train"], + "valid": ds_split["test"], + } + ) + # Save locally + os.makedirs(output_dir, exist_ok=True) + ds_dict.save_to_disk(output_dir) + + mprint(f"Dataset splits:\n{ds_dict}") + mprint(f"Saved processed datasets to {output_dir}") + + +if __name__ == "__main__": + fire.Fire(process_and_save_dataset) diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index d821fbd029..c17930ec10 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -33,6 +33,7 @@ ) from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.logger import mprint from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.mode import ( @@ -119,17 +120,24 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR ) # Convert Llama3 model to DeciLM model - hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable - convert_llama3_to_decilm( - input_dir=config.input_model_path, - output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, - ) + if runtime.global_rank == 0: + mprint("Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)") + hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + convert_llama3_to_decilm( + input_dir=config.input_model_path, + output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, + ) + runtime.wait_for_everyone() # Score_pruning_activations (distributed processing) + mprint("Compress Progress 3/8: scoring pruning activations (multi-gpu)") score_pruning_activations.launch_score_activations(hydra_cfg, runtime) # Prune the model and save pruned checkpoints if runtime.global_rank == 0: + mprint( + "Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" + ) pruning_ckpts.launch_prune_ckpt(hydra_cfg) runtime.wait_for_everyone() @@ -209,11 +217,16 @@ def run_search(self) -> None: # Build_library_and_stats (single process) if runtime.global_rank == 0: + mprint( + "Compress Progress 5/8: building replacement library and subblock statistics (single-gpu)" + ) build_library_and_stats.launch_build_library_and_stats(hydra_cfg) runtime.wait_for_everyone() # Calc_one_block_scores (distributed processing) + mprint("Compress Progress 6/8: calculating one block scores (multi-gpu)") scoring.launch_scoring(hydra_cfg, runtime) # mip_and_realize_models (distributed processing) + mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) diff --git a/modelopt/torch/_compress/tools/logger.py b/modelopt/torch/_compress/tools/logger.py new file mode 100644 index 0000000000..3e8e213ca2 --- /dev/null +++ b/modelopt/torch/_compress/tools/logger.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import inspect +import logging +import os +import sys + +import torch.distributed.launch # noqa: F401 + +logging.getLogger("fsspec.local").setLevel(logging.ERROR) +logging.getLogger("websockets.client").setLevel(logging.WARN) +logging.getLogger("websockets.server").setLevel(logging.WARN) +logging.getLogger("websockets.server:connection").setLevel(logging.WARN) + + +class LogColors: + BLUE = "\033[94m" + CYAN = "\033[96m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + RESET = "\033[0m" + + +class DistributedLogger(logging.Logger): + verbosity = logging.ERROR + + def __init__(self, name, level=logging.DEBUG): + super().__init__(name, level) + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.global_rank = int(os.environ.get("RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + + def dist_log(self, msg: str, ranks: str = "main"): + """ + Log parameter msg with the given ranks. + parameter ranks: + "all": log with all ranks + "main": log with only rank 0 in node 0 + "last": log with only rank -1 in node 0 + "local_main": log with only rank 0 in all nodes + """ + # print(msg, ranks) + if ranks not in ["all", "main", "local_main", "last"]: + raise NotImplementedError( + f"Could not broadcast msg {msg} - " + f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}" + ) + # All ranks to print + if ranks == "all": + pass + + # Only main rank at node 0 to print + elif ( + (ranks == "main" and self.global_rank != 0) + or (ranks == "last" and self.local_rank != self.world_size - 1) + or (ranks == "local_main" and self.local_rank != 0) + ): + return + + message_source = self.get_caller_location() + + self.info( + f"{LogColors.GREEN}[rank-{self.global_rank}]{LogColors.RESET}[{message_source}]\t{msg}" + ) + + # def dist_warning(self, msg): + # if self.verbosity <= logging.WARNING: + # self.warning(f"[rank-{self.global_rank}] " + msg) + + @staticmethod + def get_caller_location() -> str: + # Get the caller's stack frame + frame = inspect.currentframe() + + # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source + caller_frame = frame.f_back.f_back.f_back + + # Get the filename and line number from the caller's stack frame + filename = os.path.basename(caller_frame.f_code.co_filename) + lineno = caller_frame.f_lineno + return f"{filename}:{lineno}" + + +# Initialize logger +logging.setLoggerClass(DistributedLogger) +logger = logging.getLogger(__name__) +logger.propagate = False + +formatter = logging.Formatter("[%(asctime)s]%(message)s") +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(formatter) +handler.setLevel(logging.DEBUG) +logger.addHandler(handler) + +# Manually edit torch logger +torch_logger = logging.getLogger("torch") +torch_logger.handlers = logger.handlers +torch_logger.propagate = False + +# Manually edit deepspeed logger + +# Show some love to Mac & Windows users who can't easily install deepspeed ;) +# This is allowing running tests on Mac & Windows and train in non-DDP +try: + from deepspeed.utils import logger as deepspeed_logger + + deepspeed_logger.handlers = logger.handlers + deepspeed_logger.propagate = False +except ImportError: + # If deepspeed is not installed - no op + pass + +# Define a custom function to redirect warnings to logger +# def custom_warning_handler(message, category, filename, lineno, file=None, line=None): +# logger.dist_warning(f'{category.__name__}: {message} (in {filename}, line {lineno})') + + +# Use the custom warning handler +# warnings.showwarning = custom_warning_handler + +logger: DistributedLogger + + +def aprint(msg: str | None): + """ + All ranks from all nodes prints + """ + return logger.dist_log(msg=msg, ranks="all") + + +def lmprint(msg: str | None): + """ + All local main ranks prints (rank 0 in each node) + """ + return logger.dist_log(msg=msg, ranks="local_main") + + +def mprint(msg: str | None): + """ + Master prints only (rank 0 in node 0) + """ + return logger.dist_log(msg=msg, ranks="main") + + +def lprint(msg: str | None): + """ + Last rank prints only (rank -1 in node 0) + """ + return logger.dist_log(msg=msg, ranks="last") diff --git a/setup.py b/setup.py index 568131f486..2d041b1841 100644 --- a/setup.py +++ b/setup.py @@ -104,6 +104,8 @@ "fire", "hydra-core==1.3.2", "omegaconf==2.3.0", + "wandb~=0.17.5", + "lru-dict", ], } diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 3d5d6b666d..f945e75ff4 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -40,13 +40,10 @@ # /workspace/puzzletron # # submit_job --partition interactive --time 0 \ -# --image gitlab-master.nvidia.com/deci/puzzletron:trtllm_main \ +# --image gitlab-master.nvidia.com/deci/puzzletron:modelopt_main \ # --workdir $MODELOPT SRC DIRECTORY --interactive --gpu 1 # -# pip install mip -# pip install lru-dict -# -# export PYTHONPATH=$PYTHONPATH:/workspace/puzzletron/v1 +# export PYTHONPATH=$PYTHONPATH:.:/workspace/puzzletron/v1 # # pytest -s -v ./tests/experimental/torch/_compress/test_compress.py::test_compress -o addopts="" From b121945aefa51c2c89f78416c1658bc8b65939dd Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 13 Nov 2025 17:50:45 +0100 Subject: [PATCH 09/62] Add llama converter (no dependency on internal Nvidia code) - part 1/2 (#545) ## What does this PR do? Add llama converter (no dependency on internal Nvidia code) - part 1/2 - change top-level dependencies in convert_llama3_to_decilm.py from puzzle_tools.... to modelopt..... - added modelopt.torch._compress.tools module - remove tokenization_mistral.py - not used scope of 2/2 part (will come once part 1/2 is merged): - change all deeper dependencies from from puzzle_tools.... to modelopt.... - test_convert_llama3_config_to_decilm_config.py should run without any internal nvidia dependencies --------- Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/compress.py | 2 +- .../_compress/decilm/conversion_utils.py | 157 ++++++ .../converters/convert_llama3_to_decilm.py | 6 +- .../deci_lm_hf_code/tokenization_mistral.py | 374 --------------- .../nas/plugins/compress_nas_plugin.py | 4 +- .../torch/_compress/tools/checkpoint_utils.py | 191 ++++++++ .../_compress/tools/checkpoint_utils_hf.py | 448 ++++++++++++++++++ modelopt/torch/_compress/tools/common.py | 22 + modelopt/torch/_compress/{ => tools}/hydra.py | 0 .../torch/_compress/{ => tools}/runtime.py | 0 .../_compress/nas/plugins/test_nas_convert.py | 2 +- .../_compress/nas/plugins/test_nas_search.py | 2 +- .../torch/_compress/test_compress.py | 2 +- 13 files changed, 827 insertions(+), 383 deletions(-) create mode 100644 modelopt/torch/_compress/decilm/conversion_utils.py delete mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py create mode 100644 modelopt/torch/_compress/tools/checkpoint_utils.py create mode 100644 modelopt/torch/_compress/tools/checkpoint_utils_hf.py create mode 100644 modelopt/torch/_compress/tools/common.py rename modelopt/torch/_compress/{ => tools}/hydra.py (100%) rename modelopt/torch/_compress/{ => tools}/runtime.py (100%) diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 455cf3f8ec..df953bb908 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -28,7 +28,7 @@ from omegaconf import DictConfig from puzzle_tools.runtime import IRuntime -from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir def compress( diff --git a/modelopt/torch/_compress/decilm/conversion_utils.py b/modelopt/torch/_compress/decilm/conversion_utils.py new file mode 100644 index 0000000000..deb080ea21 --- /dev/null +++ b/modelopt/torch/_compress/decilm/conversion_utils.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from collections import defaultdict + +from safetensors.torch import load_file, save_file +from tqdm import tqdm + + +def convert_name(name): + return name.replace("feed_forward", "mlp").replace("language_model.", "") + + +def convert_routed_experts_weight(llama_name, weight): + assert ".experts." in llama_name, "Only use this func to convert weights of routed experts" + llama_name_prefix = llama_name.split(".experts.")[0] + deci_name_prefix = convert_name(llama_name_prefix) + + experts_state_dict = {} + for i_expert, expert_weight in enumerate(weight.unbind(dim=0)): + expert_prefix = f"{deci_name_prefix}.experts.{i_expert}" + if "gate_up_proj" in llama_name: + gate_weight, up_weight = expert_weight.transpose(0, 1).chunk(2, dim=0) + experts_state_dict[f"{expert_prefix}.gate_proj.weight"] = gate_weight.contiguous() + experts_state_dict[f"{expert_prefix}.up_proj.weight"] = up_weight.contiguous() + elif "down_proj" in llama_name: + down_weight = expert_weight.transpose(0, 1) + experts_state_dict[f"{expert_prefix}.down_proj.weight"] = down_weight.contiguous() + else: + raise ValueError(f"Unknown expert weight: {llama_name}") + + return experts_state_dict + + +def get_layer_subblock(param): + if param.startswith("model.embed_tokens."): + return "embeddings" + if param.startswith("lm_head.") or param == "model.norm.weight": + return "lm_head" + m = re.match(r"model\.layers\.(\d+)\.(.+)", param) + if m: + layer, suffix = m.groups() + if suffix.startswith(("self_attn.", "input_layernorm.weight")): + return f"block_{layer}_attention" + elif suffix.startswith(("mlp.", "post_attention_layernorm.weight")): + return f"block_{layer}_ffn" + return None + + +def convert_model_weights_to_decilm(llama_hf_dir, output_dir, is_llama4=False): + index_path = os.path.join(llama_hf_dir, "model.safetensors.index.json") + single_file_path = os.path.join(llama_hf_dir, "model.safetensors") + + # Check if we have a sharded model (with index) or single file model + if os.path.exists(index_path): + # Sharded model - use existing logic + with open(index_path) as f: + index = json.load(f) + param_to_file = index["weight_map"] + all_param_names = list(param_to_file.keys()) + elif os.path.exists(single_file_path): + # Single file model - create a synthetic index + data = load_file(single_file_path) + all_param_names = list(data.keys()) + param_to_file = dict.fromkeys(all_param_names, "model.safetensors") + else: + raise FileNotFoundError( + f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." + ) + name_map = { + name: convert_name(name) + for name in all_param_names + if name.startswith("language_model.") or not is_llama4 + } + + # Reverse map: file -> set of params + file_to_params = defaultdict(set) + for name, file in param_to_file.items(): + file_to_params[file].add(name) + + # Determine subblocks needed + subblocks = defaultdict(list) + for old_name, new_name in name_map.items(): + subblock = get_layer_subblock(new_name) + if subblock: + subblocks[subblock].append((old_name, new_name)) + + # Output directory + out_dir = os.path.join(output_dir, "subblocks_safetensors") + os.makedirs(out_dir, exist_ok=True) + + # New weight index + new_index = {"metadata": {"format": "pt"}, "weight_map": {}} + + # For single file models, load all data once + if os.path.exists(single_file_path) and not os.path.exists(index_path): + all_data = load_file(single_file_path) + else: + all_data = None + + for subblock, param_pairs in tqdm(subblocks.items(), desc="Processing subblocks"): + tensors = {} + + if all_data is not None: + # Single file model - get tensors from pre-loaded data + for old_name, new_name in param_pairs: + if old_name in all_data: + if ".experts." not in old_name: + tensors[new_name] = all_data[old_name] + else: + experts_state_dict = convert_routed_experts_weight( + old_name, all_data[old_name] + ) + tensors.update(experts_state_dict) + else: + # Sharded model - load only needed files for this subblock + param_files = {param_to_file[old] for old, _ in param_pairs} + for file in param_files: + data = load_file(os.path.join(llama_hf_dir, file)) + for old_name, new_name in param_pairs: + if param_to_file[old_name] == file and old_name in data: + if ".experts." not in old_name: + tensors[new_name] = data[old_name] + else: + experts_state_dict = convert_routed_experts_weight( + old_name, data[old_name] + ) + tensors.update(experts_state_dict) + + # Save this subblock + subblock_file = f"{subblock}.safetensors" + save_file(tensors, os.path.join(out_dir, subblock_file)) + + # Update index + for new_name in tensors: + new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}" + + # Save new index file + with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f: + json.dump(new_index, f, indent=2) + + print(f"✅ Finished saving subblocks and index to {output_dir}") diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py index 96b96f3510..4df9f009a6 100644 --- a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py +++ b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py @@ -21,12 +21,12 @@ import torch from fire import Fire -from puzzle_tools.checkpoint_utils import copy_tokenizer -from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code -from puzzle_tools.conversion_utils import convert_model_weights_to_decilm from transformers import LlamaConfig +from modelopt.torch._compress.decilm.conversion_utils import convert_model_weights_to_decilm from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.tools.checkpoint_utils import copy_tokenizer +from modelopt.torch._compress.tools.checkpoint_utils_hf import copy_deci_lm_hf_code """ example: diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py deleted file mode 100644 index e67674a092..0000000000 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py +++ /dev/null @@ -1,374 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Based on https://github.com/vllm-project/vllm/blob/739e03b3449a7f3b0a81ebc30b9555305d914e2d/vllm/transformers_utils/tokenizers/mistral.py -# mypy: ignore-errors - -import os -import re -import sys -from pathlib import Path -from shutil import copyfile -from typing import TYPE_CHECKING, Any - -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer -from transformers.utils import logging - -if TYPE_CHECKING: - from mistral_common.protocol.instruct.request import ChatCompletionRequest - -logger = logging.get_logger(__name__) - - -def _called_from_vllm() -> bool: - frame = sys._getframe(1) - while frame: - mod = frame.f_globals.get("__name__", "") - if mod == "vllm" or mod.startswith("vllm."): - return True - frame = frame.f_back - return False - - -class HFAdaptedMistralTokenizer(PreTrainedTokenizer): - """ - In order to save the tokenizer, do the following: - ``` - # from import HFAdaptedMistralTokenizer - # from mistral_common.tokens.tokenizers.base import SpecialTokens - HFAdaptedMistralTokenizer.register_for_auto_class("AutoTokenizer") - tokenizer = HFAdaptedMistralTokenizer("", chat_template="dummy") - tokenizer.add_special_tokens( - {"additional_special_tokens": [v.value for _, v in SpecialTokens.__members__.items()]} - ) - tokenizer.save_pretrained("") - ``` - """ - - vocab_files_names = {"path_indicator": "tokenizer_config.json"} - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - path_indicator: str, - unk_token: str | None = None, - bos_token: str | None = None, - eos_token: str | None = None, - pad_token: str | None = None, - add_bos_token: bool = True, - add_eos_token: bool = False, - clean_up_tokenization_spaces: bool = False, - **kwargs, - ): - path_indicator: Path = Path(path_indicator) - if path_indicator.name == "tokenizer_config.json": - path_indicator = path_indicator.parent - if path_indicator.is_dir(): - tokenizer_file_name = _find_tokenizer_file(os.listdir(path_indicator)) - tokenizer_file = str(path_indicator / tokenizer_file_name) - else: - tokenizer_file = path_indicator - self._mistral_tokenizer_path = str(tokenizer_file) - - from mistral_common.tokens.tokenizers.mistral import MistralTokenizer as MistralTokenizer - - self._mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file) - self._instruct_tokenizer = self._mistral_tokenizer.instruct_tokenizer - - # Copied from https://github.com/patrickvonplaten/vllm/blob/6cca3d8c330e169bbf386561c441ca5f3879cf85/vllm/transformers_utils/tokenizers/mistral.py - self.version: int = int( - self._instruct_tokenizer.tokenizer.version.value.split("v")[-1].split("m")[0] - ) - - tokenizer_ = self._instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer - - self.is_tekken = isinstance(tokenizer_, Tekkenizer) - from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer - - self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - if self.is_tekken: - # Make sure special tokens will not raise - tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE - elif self.is_spm: - pass - else: - raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") - - self._vocab = tokenizer_.vocab() - # Convert to a Dict[str, int] to match protocol, but this is a lossy - # conversion. There may be multiple token ids that decode to the same - # string due to partial UTF-8 byte sequences being converted to � - self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} - self._tokenizer = tokenizer_ - self._max_token_id = self.vocab_size - 1 - self.vocab = self._vocab_dict - - bos_token = ( - bos_token - if bos_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.bos_id], - normalized=False, - special=True, - ) - ) - eos_token = ( - eos_token - if eos_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.eos_id], - normalized=False, - special=True, - ) - ) - unk_token = ( - unk_token - if unk_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.unk_id], - normalized=False, - special=True, - ) - ) - pad_token = ( - pad_token - if pad_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.pad_id], - normalized=False, - special=True, - ) - ) - - self._add_bos_token = add_bos_token - self._add_eos_token = add_eos_token - - self._in_vllm = _called_from_vllm() - - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - - @property - def vocab_size(self): - """Returns vocab size""" - return self._tokenizer.n_words - - def get_vocab(self): - """Returns vocab as a dict""" - return self._vocab_dict - - def tokenize( - self, - text: str, - pair: str | None = None, - add_special_tokens: bool | None = None, - **kwargs, - ) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokens - - if add_special_tokens is None: - bos = self._add_bos_token - eos = self._add_eos_token - else: - bos = add_special_tokens - eos = add_special_tokens - - input_ids = [] - parts = self.tokens_trie.split(text) - - in_vllm_chat_completion_mode = False - if ( - self._in_vllm - and len(parts) > 1 - and parts[0] == SpecialTokens.bos.value - and parts[1] == SpecialTokens.begin_inst.value - ): - # This is a dangerous hack to make the tokenizer work with vLLM. - # It means we are in chat completion mode. - bos = False - eos = False - in_vllm_chat_completion_mode = True - - if os.environ.get("HF_TOKENIZE_FORCE_NO_SPECIAL_TOKENS", "0") == "1": - bos = False - eos = False - - if not self._in_vllm or in_vllm_chat_completion_mode: - for part in parts: - if part in self.additional_special_tokens and part in self._vocab_dict: - input_ids.append(self._convert_token_to_id(part)) - else: - input_ids.extend(self._tokenizer.encode(part, bos=bos, eos=eos)) - else: - # Doesn't tokenize special tokens properly, but this is the behavior of vLLM when we are in completion mode. - input_ids = self._tokenizer.encode(text, bos=bos, eos=eos) - - if os.environ.get("HF_TOKENIZE_ABUSE", "1") == "1": - # A lot faster than the other option - return input_ids - else: - return [self._convert_id_to_token(token_id) for token_id in input_ids] - - def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: - if len(tokens) > 0 and isinstance(tokens[0], int): - return tokens - return super().convert_tokens_to_ids(tokens) - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self._vocab_dict[token] - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - piece = self._tokenizer.id_to_piece(index) - return piece if isinstance(piece, str) else piece.value - - def convert_tokens_to_string(self, tokens: list[str]) -> str: - from mistral_common.tokens.tokenizers.base import SpecialTokens - - if self.is_tekken: - tokens = [ - t - for t in tokens - if (t is SpecialTokens.tool_calls or t not in self._tokenizer._all_special_tokens) - ] - - if any(isinstance(t, bytes) for t in tokens): - # we need to encode and decode all tokens again - shift = self._tokenizer.num_special_tokens - - def _token_to_id(t: str): - t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t - try: - return shift + self._tokenizer._tekken_token2id_nospecial[t_bytes] - except KeyError: - logger.warning( - "Failed to convert token %s to id, replacing with ", - t_bytes, - ) - return self._tokenizer.unk_id - - ids = [_token_to_id(t) for t in tokens] - decoded = self._tokenizer.decode(ids) - else: - decoded = "".join(tokens) - else: - # make sure certain special tokens like Tool calls are - # not decoded - special_tokens = {SpecialTokens.tool_calls} - regular_tokens: list[str] = [] - decoded_list = [] - - for token in tokens: - if token in special_tokens: - if regular_tokens: - decoded_list.append(self._tokenizer.decode(regular_tokens)) - regular_tokens = [] - decoded_list.append(token) - else: - regular_tokens.append(token) - - if regular_tokens: - decoded_list.append(self._tokenizer.decode(regular_tokens)) # type: ignore[no-untyped-call] - - decoded = "".join(decoded_list) - - return decoded - - def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]: - """ - Use this method to save the full tokenizer file. - """ - - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join(save_directory, "tekken.json") - - if os.path.abspath(self._mistral_tokenizer_path) != os.path.abspath(out_vocab_file): - copyfile(self._mistral_tokenizer_path, out_vocab_file) - - return (out_vocab_file,) - - def apply_chat_template( - self, - conversation: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - tokenize: bool = True, - **kwargs, - ) -> list[int]: - request = _make_mistral_chat_completion_request(conversation, tools) - encoded = self._mistral_tokenizer.encode_chat_completion(request) - if tokenize: - # encode-decode to get clean prompt - return encoded.tokens - else: - return encoded.text - - -def _find_tokenizer_file(files: list[str]): - file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") - - matched_files = [file for file in files if file_pattern.match(file)] - if len(matched_files) > 1: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. Make sure only one Mistral " - f"tokenizer is present in {files}." - ) - elif len(matched_files) == 0: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " - f"tokenizer is present in {files}." - ) - - return matched_files[0] - - -def _make_mistral_chat_completion_request( - messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None -) -> "ChatCompletionRequest": - last_message = messages[-1] - if last_message["role"] == "assistant": - last_message["prefix"] = True - - # mistral-common requires AssistantMessage content to be string [1]. - # - # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 - for message in messages: - if message.get("role") == "assistant": - content = message.get("content") - if isinstance(content, list): - content = "\n".join(chunk.get("text") for chunk in content) - message["content"] = content - - # The Mistral client, in comparison to the OpenAI client, requires the - # "parameters" dict to be present, even if it's empty. - if tools: - for function in [tool["function"] for tool in tools if tool["type"] == "function"]: - if function.get("parameters") is None: - function["parameters"] = {} - - from mistral_common.protocol.instruct.request import ChatCompletionRequest - - return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index c17930ec10..1cbfa5f30d 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -31,9 +31,9 @@ from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.mode import ( diff --git a/modelopt/torch/_compress/tools/checkpoint_utils.py b/modelopt/torch/_compress/tools/checkpoint_utils.py new file mode 100644 index 0000000000..4a05f82bb0 --- /dev/null +++ b/modelopt/torch/_compress/tools/checkpoint_utils.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import concurrent.futures +import warnings +from functools import partial +from pathlib import Path +from typing import Literal, TypeVar + +import torch +from safetensors.torch import load_file as safe_load_file +from torch import nn +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from modelopt.torch._compress.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch._compress.tools.common import infer_weights_dtype + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +STATE_DICT_FILE_NAME = "model.pth" + +warnings.filterwarnings("ignore", "You are using `torch.load` with `weights_only=False`*.") + + +def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: + checkpoint_dir = _normalize_checkpoint_dir(checkpoint_dir) + + if (state_dict_path := checkpoint_dir / STATE_DICT_FILE_NAME).exists(): + return torch.load(state_dict_path, map_location="cpu", weights_only=False) + + if (safetensors_subblocks_dir := checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(safetensors_subblocks_dir) + + if (pth_subblocks_dir := checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(pth_subblocks_dir) + + if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( + checkpoint_dir / SAFE_WEIGHTS_NAME + ).exists(): + from utils.sharded_checkpoint_utils import ( + load_sharded_state_dict, # local import to avoid circular import + ) + + return load_sharded_state_dict(checkpoint_dir) + + raise FileNotFoundError( + f"Couldn't find state dict path or subblocks dir inside {checkpoint_dir}" + ) + + +def _normalize_checkpoint_dir(checkpoint_dir: Path | str) -> Path: + checkpoint_dir = Path(checkpoint_dir) + if checkpoint_dir.is_file(): + checkpoint_dir = checkpoint_dir.parent + return checkpoint_dir + + +def _load_state_dict_from_subblocks(subblocks_dir: Path) -> dict[str, torch.Tensor]: + torch_paths = list(subblocks_dir.glob("*.pth")) + safetensors_paths = list(subblocks_dir.glob("*.safetensors")) + + if len(torch_paths) != 0: + load_fn = partial(torch.load, map_location="cpu", weights_only=False) + file_paths = torch_paths + elif len(safetensors_paths) != 0: + load_fn = safe_load_file + file_paths = safetensors_paths + else: + raise ValueError(f"No tensor files found in {subblocks_dir=}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + state_dict_shards = list(executor.map(load_fn, file_paths)) + + state_dict = {k: v for shard in state_dict_shards for k, v in shard.items()} + return state_dict + + +NNModule = TypeVar("NNModule", bound=nn.Module) + + +def init_module_with_state_dict( + state_dict: dict[str, torch.Tensor], + module_cls: type[NNModule], + *init_args, + **init_kwargs, +) -> NNModule: + weights_dtype = infer_weights_dtype(state_dict) + module = init_empty_module(module_cls, weights_dtype, *init_args, **init_kwargs) + module.load_state_dict(state_dict) + return module + + +def init_empty_module( + module_cls: type[NNModule], + dtype: torch.dtype, + *init_args, + **init_kwargs, +) -> NNModule: + default_dtype = torch.get_default_dtype() + current_device = torch.ones(1).device + torch.set_default_dtype(dtype) + module = skip_init(module_cls, *init_args, device=current_device, **init_kwargs) + torch.set_default_dtype(default_dtype) + return module + + +def skip_init(module_cls, *args, **kwargs) -> nn.Module: + """ + Heavily inspired by torch.nn.utils.skip_init but does not require the module to accept a "device" kwarg. + """ + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError(f"Expected a Module; got {module_cls}") + + final_device = kwargs.pop("device", "cpu") + with torch.device("meta"): + module = module_cls(*args, **kwargs) + + module = module.to_empty(device=final_device) + return module + + +def is_valid_decilm_checkpoint(checkpoint_dir: Path | str) -> bool: + """Validate that a checkpoint is in DeciLM format (has block_configs). + + Args: + checkpoint_dir: Path to checkpoint directory + + Returns: + True if checkpoint is valid DeciLM format, False otherwise + """ + try: + model_config = load_model_config(checkpoint_dir) + if model_config.block_configs is None: + warnings.warn( + f"Skipping checkpoint '{checkpoint_dir}' - not in DeciLM format (missing block_configs)" + ) + return False + return True + except Exception as e: + warnings.warn(f"Skipping checkpoint '{checkpoint_dir}' - failed to load config: {e}") + return False + + +def copy_tokenizer( + source_dir_or_tokenizer_name: Path | str, + target_dir: Path | str, + on_failure: Literal["raise", "warn"] = "raise", +) -> None: + """ + Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available) + to avoid collision between transformers versions. + """ + source_tokenizer_name_path = Path(source_dir_or_tokenizer_name) / "tokenizer_name.txt" + if source_tokenizer_name_path.exists(): + source_dir_or_tokenizer_name = source_tokenizer_name_path.read_text().strip() + + tokenizer = None + try: + tokenizer = AutoTokenizer.from_pretrained( + source_dir_or_tokenizer_name, trust_remote_code=True + ) + except Exception: + message = f"Couldn't load tokenizer from '{source_dir_or_tokenizer_name}'" + if on_failure == "raise": + raise FileNotFoundError(message) + else: + warnings.warn(message) + + if tokenizer is not None: + target_dir = Path(target_dir) + target_dir.mkdir(exist_ok=True, parents=True) + tokenizer.save_pretrained(target_dir) + + target_tokenizer_name_path = target_dir / "tokenizer_name.txt" + is_given_tokenizer_name_as_argument = not Path(source_dir_or_tokenizer_name).exists() + if is_given_tokenizer_name_as_argument: + target_tokenizer_name_path.write_text(source_dir_or_tokenizer_name) diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py new file mode 100644 index 0000000000..c686c10272 --- /dev/null +++ b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -0,0 +1,448 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import concurrent.futures +import fcntl +import os +import shutil +import time +import warnings +from collections import defaultdict +from collections.abc import Callable, Mapping +from pathlib import Path +from typing import Any, BinaryIO + +import torch +from logger import mprint +from puzzle_tools import deci_lm_hf_code +from puzzle_tools.common import infer_weights_dtype +from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from puzzle_tools.robust_json import json_dumps +from safetensors.torch import save_file as safe_save_file +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from utils.post_init_sparse import SparsityMethod + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +RELATIVE_SUBBLOCKS_DIR = Path(SAFETENSORS_SUBBLOCKS_DIR_NAME) + + +# TODO: (esegal) Should ask the model for something like this +NON_LAYER_MODULE_TO_FILE_TYPE = { + "model.embed_tokens": "embeddings", + "model.norm": "lm_head", + "lm_head": "lm_head", +} +MODULE_WITHIN_LAYER_TO_FILE_TYPE = { + "input_layernorm": "attention", + "self_attn": "attention", + "post_attention_layernorm": "ffn", + "mlp": "ffn", + "parallel_blocks": "multi_block", +} +LAYERS_MODULE_NAME = "model.layers" + +warnings.filterwarnings("ignore", "You are using `torch.load` with `weights_only=False`*.") + + +def load_checkpoint( + checkpoint_dir: Path | str, + model_config_overrides: dict | None = None, + ignore_unexpected_config_keys: bool = False, +) -> DeciLMForCausalLM: + """ + Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your + local repo code, not the code inside the checkpoint. + """ + from modelopt.torch._compress.tools.checkpoint_utils import ( + load_state_dict, # prevent circular import + ) + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + model_config = load_model_config( + checkpoint_dir, model_config_overrides, ignore_unexpected_config_keys + ) + + # Without sparsity we could have done: + # model = DeciLMForCausalLM.from_pretrained(pretrained_model_name_or_path=checkpoint_dir, config=model_config) + state_dict = load_state_dict(checkpoint_dir) + state_dict, sparsity_masks = SparsityMethod.fix_state_dict_inplace(state_dict, verbose=True) + dtype = infer_weights_dtype(state_dict) + model = DeciLMForCausalLM.from_pretrained( + pretrained_model_name_or_path=None, + config=model_config, + state_dict=state_dict, + torch_dtype=dtype, + ) + SparsityMethod().apply_masks(model, sparsity_masks) + + return model + + +def load_model_config( + checkpoint_dir: Path | str, + model_config_overrides: Mapping | None = None, + ignore_unexpected_config_keys: bool = False, +) -> DeciLMConfig: + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + if model_config_overrides is None: + model_config_overrides = {} + + config, unused_kwargs = DeciLMConfig.from_pretrained( + checkpoint_dir, return_unused_kwargs=True, **model_config_overrides + ) + + if not ignore_unexpected_config_keys: + if unused_kwargs: + raise ValueError(f"Unexpected config keys: {unused_kwargs.keys()}") + + return config + + +def save_checkpoint(model: DeciLMForCausalLM, checkpoint_dir: Path | str) -> None: + _save_checkpoint(model.config, model.state_dict(), checkpoint_dir) + + +def _save_checkpoint( + model_config: DeciLMConfig, + state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + max_workers: int | None = None, # Now optional - will auto-calculate if None +) -> None: + mprint("=== Starting _save_checkpoint detailed profiling ===") + total_start_time = time.time() + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + # Phase 1: Create directory and save config + phase1_start_time = time.time() + checkpoint_dir.mkdir(parents=True, exist_ok=True) + model_config.save_pretrained(checkpoint_dir) + phase1_time = time.time() - phase1_start_time + mprint(f"Phase 1 - Directory creation and config save: {phase1_time:.2f}s") + + # Phase 2: Save subblocks (main model weights) with auto-calculated worker count + phase2_start_time = time.time() + save_subblocks( + state_dict, + checkpoint_dir, + multi_threaded=True, + max_workers=max_workers, # Will auto-calculate if None + ) + phase2_time = time.time() - phase2_start_time + mprint(f"Phase 2 - Save subblocks (model weights): {phase2_time:.2f}s") + + # Phase 3: Save safetensors index + phase3_start_time = time.time() + save_safetensors_index(model_config, checkpoint_dir) + phase3_time = time.time() - phase3_start_time + mprint(f"Phase 3 - Save safetensors index: {phase3_time:.2f}s") + + # Phase 4: Copy HF code + phase4_start_time = time.time() + copy_deci_lm_hf_code(checkpoint_dir) + phase4_time = time.time() - phase4_start_time + mprint(f"Phase 4 - Copy HF code: {phase4_time:.2f}s") + + total_time = time.time() - total_start_time + mprint(f"=== _save_checkpoint completed in {total_time:.2f}s ===") + mprint( + f"Breakdown: Config {phase1_time:.1f}s + Subblocks {phase2_time:.1f}s + " + f"Index {phase3_time:.1f}s + HF code {phase4_time:.1f}s" + ) + mprint( + f"Save percentage breakdown: Config {phase1_time / total_time * 100:.1f}% + " + f"Subblocks {phase2_time / total_time * 100:.1f}% + " + f"Index {phase3_time / total_time * 100:.1f}% + " + f"HF code {phase4_time / total_time * 100:.1f}%" + ) + + # Performance metrics + if phase2_time > 0: + subblocks_percentage = phase2_time / total_time * 100 + actual_workers = max_workers if max_workers else "auto" + mprint( + f"I/O optimization: Subblocks were {subblocks_percentage:.1f}% of total save time " + f"(max_workers={actual_workers})" + ) + + +def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: + from modelopt.torch._compress.tools.checkpoint_utils import ( + load_state_dict, # prevent circular import + ) + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + model_config = load_model_config(checkpoint_dir) + state_dict = load_state_dict(checkpoint_dir) + save_subblocks(state_dict, checkpoint_dir) + + if (index_path := checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists(): + index_path.rename(checkpoint_dir / f"before_splitting.{SAFE_WEIGHTS_INDEX_NAME}") + save_safetensors_index(model_config, checkpoint_dir) + + +def save_subblocks( + state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + multi_threaded: bool = True, + max_workers: int | None = None, # Now optional - will auto-calculate if None +) -> None: + mprint("=== Starting save_subblocks detailed profiling ===") + subblocks_start_time = time.time() + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + # Step 1: Build weight map + weight_map_start_time = time.time() + weight_map = _build_safetensors_weight_map( + state_dict=state_dict, + non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, + module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, + layers_module_name=LAYERS_MODULE_NAME, + ) + weight_name_to_filename = {k: checkpoint_dir / v for k, v in weight_map.items()} + weight_map_time = time.time() - weight_map_start_time + mprint(f" Step 1 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") + + # Step 2: Create subblocks directory + dir_create_start_time = time.time() + subblocks_path = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_path.mkdir(parents=True, exist_ok=True) + dir_create_time = time.time() - dir_create_start_time + mprint(f" Step 2 - Create directory: {dir_create_time:.2f}s") + + # Step 3: Organize tensors by file + organize_start_time = time.time() + filename_to_partial_state_dict = defaultdict(dict) + total_tensor_size = 0 + for weight_name, weight in state_dict.items(): + if weight_name in weight_map: + # Ensure tensor is contiguous and on CPU for faster I/O + tensor = ( + weight.contiguous().cpu() if weight.device.type != "cpu" else weight.contiguous() + ) + filename_to_partial_state_dict[weight_name_to_filename[weight_name]][weight_name] = ( + tensor + ) + total_tensor_size += weight.numel() * weight.element_size() + organize_time = time.time() - organize_start_time + mprint( + f" Step 3 - Organize tensors: {organize_time:.2f}s ({total_tensor_size / (1024**3):.2f}GB total)" + ) + + # Step 4: Prepare save arguments and auto-calculate optimal I/O workers + prepare_start_time = time.time() + safe_save_kwargs = [ + {"tensors": partial_state_dict, "filename": filename, "metadata": {"format": "pt"}} + for filename, partial_state_dict in filename_to_partial_state_dict.items() + ] + + # Auto-calculate optimal I/O workers: min(cpu_count, num_files) + if max_workers is None: + cpu_count = os.cpu_count() or 1 + num_files = len(safe_save_kwargs) + max_workers = min(cpu_count, num_files) + mprint( + f" Auto-calculated I/O workers: min({cpu_count} CPUs, {num_files} files) = {max_workers}" + ) + else: + mprint(f" Using specified I/O workers: {max_workers}") + + prepare_time = time.time() - prepare_start_time + mprint(f" Step 4 - Prepare save args: {prepare_time:.2f}s ({len(safe_save_kwargs)} files)") + + # Step 5: Save files with optimal worker count + save_start_time = time.time() + if multi_threaded: + mprint(f" Using multi-threaded saving with {max_workers} workers...") + + def optimized_safe_save(kwargs): + try: + safe_save_file(**kwargs) + return True + except Exception as e: + mprint(f" Error saving {kwargs['filename']}: {e}") + return False + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map(optimized_safe_save, safe_save_kwargs)) + + # Check for any failures + failed_saves = sum(1 for r in results if not r) + if failed_saves > 0: + mprint(f" Warning: {failed_saves} files failed to save") + else: + mprint(" Using single-threaded saving...") + for kwargs in safe_save_kwargs: + safe_save_file(**kwargs) + + save_time = time.time() - save_start_time + mprint(f" Step 5 - Save files: {save_time:.2f}s ({max_workers} workers)") + + subblocks_total_time = time.time() - subblocks_start_time + mprint(f"=== save_subblocks completed in {subblocks_total_time:.2f}s ===") + mprint( + f" Breakdown: WeightMap {weight_map_time:.1f}s + DirCreate {dir_create_time:.1f}s + " + f"Organize {organize_time:.1f}s + Prepare {prepare_time:.1f}s + Save {save_time:.1f}s" + ) + + # Calculate effective I/O speed + io_speed_gbps = (total_tensor_size / (1024**3)) / save_time if save_time > 0 else 0 + mprint(f" Effective I/O speed: {io_speed_gbps:.2f} GB/s ({max_workers} workers)") + mprint(f" Save operation was {save_time / subblocks_total_time * 100:.1f}% of total time") + + +def save_safetensors_index( + model_config: DeciLMConfig, + checkpoint_dir: Path | str, +) -> None: + mprint("=== Starting save_safetensors_index profiling ===") + index_start_time = time.time() + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + # Step 1: Create fake model on meta device + fake_model_start_time = time.time() + with torch.device("meta"): + fake_model = DeciLMForCausalLM(model_config) + fake_model_time = time.time() - fake_model_start_time + mprint(f" Step 1 - Create fake model: {fake_model_time:.2f}s") + + # Step 2: Build weight map + weight_map_start_time = time.time() + weight_map = _build_safetensors_weight_map( + state_dict=fake_model.state_dict(), + non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, + module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, + layers_module_name=LAYERS_MODULE_NAME, + ) + weight_map_time = time.time() - weight_map_start_time + mprint(f" Step 2 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") + + # Step 3: Create and write index + write_start_time = time.time() + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME + index_json = json_dumps(index) + _write_file_process_safe(index_json, index_path) + write_time = time.time() - write_start_time + mprint(f" Step 3 - Write index file: {write_time:.2f}s ({len(index_json)} chars)") + + index_total_time = time.time() - index_start_time + mprint(f"=== save_safetensors_index completed in {index_total_time:.2f}s ===") + mprint( + f" Breakdown: FakeModel {fake_model_time:.1f}s + WeightMap {weight_map_time:.1f}s + Write {write_time:.1f}s" + ) + + +def _write_text(content: str, f: BinaryIO) -> None: + f.write(content.encode("utf-8")) + + +def _write_file_process_safe( + content: Any, + path: Path | str, + write_fn: Callable[[Any, BinaryIO], None] = _write_text, +) -> None: + """ + Write a file in a multi-process safe way. + If another process tries to write the same file using this method, the current process + "gives up" and assumes that the matter is being taken care of by another process. + + write_fn is a function that receives file contents and a binary file object, + and writes the content to the file. It can be _write_text (defined above), or torch.save, + or a similar function (not safetensors.torch.save_file since it expects a path). + """ + with open(path, "wb") as f: + # Try to acquire an exclusive, non-blocking lock + try: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + return # Exit immediately if the lock is not acquired + + write_fn(content, f) # Write the content if lock is acquired + f.flush() # Ensure data is written to disk + + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + + +def _build_safetensors_weight_map( + *, + state_dict: dict[str, torch.Tensor], + non_layer_module_to_file_type: dict[str, str], + module_within_layer_to_file_type: dict[str, str], + layers_module_name: str, +) -> dict[str, Path]: + weight_map = {} + unmapped_weight_names = [] + for weight_name in state_dict: + found_match = False + for module_name, file_type in non_layer_module_to_file_type.items(): + if weight_name.startswith(f"{module_name}."): + weight_map[weight_name] = str(RELATIVE_SUBBLOCKS_DIR / f"{file_type}.safetensors") + found_match = True + if not found_match: + if weight_name.startswith(f"{layers_module_name}."): + name_parts = weight_name[len(layers_module_name) + 1 :].split(".") + layer_index = name_parts[0] + name_within_layer = ".".join(name_parts[1:]) + + for module_name, file_type in module_within_layer_to_file_type.items(): + if name_within_layer.startswith(f"{module_name}."): + weight_map[weight_name] = str( + RELATIVE_SUBBLOCKS_DIR / f"block_{layer_index}_{file_type}.safetensors" + ) + found_match = True + + if not found_match: + unmapped_weight_names.append(weight_name) + + if len(unmapped_weight_names) > 0: + raise ValueError( + f"Unmapped weight names: {unmapped_weight_names}\n" + f"Add them to the `non_layer_module_to_file_type` or " + f"`module_within_layer_to_file_type` dictionaries." + ) + + return weight_map + + +# Not really needed +def save_model_config(model_config: DeciLMConfig, checkpoint_dir: Path | str) -> None: + model_config.save_pretrained(checkpoint_dir) + + +def copy_deci_lm_hf_code(output_dir: Path | str) -> None: + """ + Copy the deci_lm_hf_code directory to the output directory. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + code_dir = Path(deci_lm_hf_code.__file__).parent + for path in code_dir.glob("*.py"): + shutil.copy(path, output_dir / path.name) diff --git a/modelopt/torch/_compress/tools/common.py b/modelopt/torch/_compress/tools/common.py new file mode 100644 index 0000000000..96db572802 --- /dev/null +++ b/modelopt/torch/_compress/tools/common.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def infer_weights_dtype(state_dict: dict[str, torch.Tensor]) -> torch.dtype: + weights_dtype = [p.dtype for p in state_dict.values() if torch.is_floating_point(p)] + weights_dtype = weights_dtype[0] if len(weights_dtype) > 0 else torch.get_default_dtype() + return weights_dtype diff --git a/modelopt/torch/_compress/hydra.py b/modelopt/torch/_compress/tools/hydra.py similarity index 100% rename from modelopt/torch/_compress/hydra.py rename to modelopt/torch/_compress/tools/hydra.py diff --git a/modelopt/torch/_compress/runtime.py b/modelopt/torch/_compress/tools/runtime.py similarity index 100% rename from modelopt/torch/_compress/runtime.py rename to modelopt/torch/_compress/tools/runtime.py diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py index 7dc2d72285..47ff2531da 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py @@ -24,7 +24,7 @@ import modelopt.torch.nas as mtn from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime # diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py index 04707d20f0..df3c1e4856 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py @@ -27,7 +27,7 @@ import modelopt.torch.nas as mtn from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime def test_nas_search(project_root_path: Path, tmp_path: Path): diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index f945e75ff4..96af36b5e9 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -26,7 +26,7 @@ from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) # using a one-click command. From 866e4001f0d7ecfda8b8b884aa47b3b52944a73f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 14 Nov 2025 19:03:32 +0100 Subject: [PATCH 10/62] llama converter is self-contained now (no dependency on internal nvidia code) (#552) ## What does this PR do? llama converter is self-contained now (no dependency on internal nvidia code) --------- Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/compress.py | 2 +- .../nas/plugins/compress_nas_plugin.py | 5 +- .../torch/_compress/tools/checkpoint_utils.py | 7 +- .../_compress/tools/checkpoint_utils_hf.py | 20 +- modelopt/torch/_compress/tools/hydra_utils.py | 81 ++++ .../torch/_compress/tools/post_init_sparse.py | 129 ++++++ modelopt/torch/_compress/tools/robust_json.py | 72 +++ .../tools/sharded_checkpoint_utils.py | 422 ++++++++++++++++++ .../torch/_compress/compress_test_utils.py | 3 +- ..._convert_llama3_config_to_decilm_config.py | 0 tests/gpu/torch/conftest.py | 8 + 11 files changed, 738 insertions(+), 11 deletions(-) create mode 100644 modelopt/torch/_compress/tools/hydra_utils.py create mode 100644 modelopt/torch/_compress/tools/post_init_sparse.py create mode 100644 modelopt/torch/_compress/tools/robust_json.py create mode 100644 modelopt/torch/_compress/tools/sharded_checkpoint_utils.py rename tests/{experimental => gpu}/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py (100%) diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index df953bb908..7d955c5cae 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -28,7 +28,7 @@ from omegaconf import DictConfig from puzzle_tools.runtime import IRuntime -from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir def compress( diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 1cbfa5f30d..13d418b69d 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -15,6 +15,9 @@ """ Compress NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). + +It is used by mtn.convert() to convert a model from HF format to DeciLM format + do pruning scoring +and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search. """ import datetime @@ -31,7 +34,7 @@ from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry diff --git a/modelopt/torch/_compress/tools/checkpoint_utils.py b/modelopt/torch/_compress/tools/checkpoint_utils.py index 4a05f82bb0..43d3c43641 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils.py @@ -14,6 +14,11 @@ # limitations under the License. # mypy: ignore-errors +""" +It provides general utilities for loading and initializing PyTorch model checkpoints, +particularly for DeciLM models. +""" + import concurrent.futures import warnings from functools import partial @@ -51,7 +56,7 @@ def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( checkpoint_dir / SAFE_WEIGHTS_NAME ).exists(): - from utils.sharded_checkpoint_utils import ( + from modelopt.torch._compress.tools.sharded_checkpoint_utils import ( load_sharded_state_dict, # local import to avoid circular import ) diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py index c686c10272..3c73498d5f 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -14,6 +14,11 @@ # limitations under the License. # mypy: ignore-errors +""" +Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, +particularly for DeciLM models. +""" + import concurrent.futures import fcntl import os @@ -26,15 +31,16 @@ from typing import Any, BinaryIO import torch -from logger import mprint -from puzzle_tools import deci_lm_hf_code -from puzzle_tools.common import infer_weights_dtype -from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM -from puzzle_tools.robust_json import json_dumps from safetensors.torch import save_file as safe_save_file from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from utils.post_init_sparse import SparsityMethod + +from modelopt.torch._compress.decilm import deci_lm_hf_code +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch._compress.tools.common import infer_weights_dtype +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.post_init_sparse import SparsityMethod +from modelopt.torch._compress.tools.robust_json import json_dumps SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" diff --git a/modelopt/torch/_compress/tools/hydra_utils.py b/modelopt/torch/_compress/tools/hydra_utils.py new file mode 100644 index 0000000000..64c4035656 --- /dev/null +++ b/modelopt/torch/_compress/tools/hydra_utils.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for hydra config initialization. +""" + +import datetime +import random +from pathlib import Path + +from hydra import compose, initialize, initialize_config_dir +from hydra.utils import get_object +from omegaconf import DictConfig, OmegaConf + + +def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int: + """ + Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage. + Used as a resolver in hydra configs. + """ + steps = (int(tokens) // int(block)) // int(mbs) + w = pct * steps + return max(1, round(w)) + + +def register_hydra_resolvers(): + OmegaConf.register_new_resolver("to_path", lambda x: Path(x)) + OmegaConf.register_new_resolver( + "random_int", lambda low, high: random.randint(int(low), int(high)) + ) + OmegaConf.register_new_resolver( + "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None + ) + OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p)) + OmegaConf.register_new_resolver("get_object", lambda x: get_object(x)) + + +def initialize_hydra_config_for_dir( + config_dir: str, config_name: str, overrides: list[str] +) -> DictConfig: + """Initialize a hydra config from an absolute path for a config directory + + Args: + config_dir (str): + config_name (str): + overrides (List[str]): + + Returns: + DictConfig: + """ + + with initialize_config_dir(version_base=None, config_dir=config_dir): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args + + +def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig: + with initialize(version_base=None, config_path=config_path): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args diff --git a/modelopt/torch/_compress/tools/post_init_sparse.py b/modelopt/torch/_compress/tools/post_init_sparse.py new file mode 100644 index 0000000000..824d0856ca --- /dev/null +++ b/modelopt/torch/_compress/tools/post_init_sparse.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import torch +from torch import nn +from torch.nn.utils.prune import custom_from_mask + +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM + +""" +Converts a state dictionary from PyTorch's pruning format (with _orig and _mask suffixes) +into a standard format with sparsified weights. +""" + + +class SparsityMethod: + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + gets a model state_dict, returns a state_dict-like mask_dict with masks + """ + + @staticmethod + def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False): + sparsity_masks = {} + for name in list(state_dict.keys()): + original_name = name.replace("_orig", "") + mask_name = original_name + "_mask" + if name[-4:] == "orig" and mask_name in state_dict: + val = state_dict[name] + mask = state_dict[name[:-4] + "mask"] + val[mask == 0] = 0 + sparsity = (val == 0).sum() / mask.numel() + sparsity_masks[original_name[:-7]] = mask + if verbose: + print(f"fix_state_dict_inplace: {name} {sparsity=}") + del state_dict[mask_name] + del state_dict[name] + state_dict[original_name] = val + if change_dtype: + for name in state_dict: + state_dict[name] = state_dict[name].to(torch.bfloat16) + return state_dict, sparsity_masks + + def filter_function(self): + pass + + def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> None: + for name, module in model.named_modules(): + if name in mask_dict: + custom_from_mask(module, "weight", mask_dict[name].to(module.weight.device)) + print(name) + print(torch.sum(mask_dict[name]) / mask_dict[name].numel()) + + def do_sparsity(self, model: DeciLMForCausalLM, mask_dict=None): + full_name_layers = [] + for block_idx, block_config in enumerate(model.config.block_configs): + ffn_names = block_config.ffn.sparsify # layers_to_sparsify_pattern[block_idx] + att_name = block_config.attention.sparsify + block = model.model.layers[block_idx] + if hasattr(block, "mlp"): + for name, m in block.mlp.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, ffn_names): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "mlp." + name + ) + if hasattr(block, "self_attn"): + for name, m in block.self_attn.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, att_name): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "self_attn." + name + ) + + if mask_dict is None: + state_dict_for_sparsifying = { + k.rstrip(".weight"): v + for k, v in model.state_dict().items() + if k.rstrip(".weight") in full_name_layers + } + mask_dict = self.calculate_masks(state_dict_for_sparsifying) + # print('Apply sparsity') + # print(full_name_layers) + # print(model.state_dict().keys()) + # print(list(mask_dict.keys())) + + self.apply_masks(model, mask_dict) + + +class SparsityMethod2o4(SparsityMethod): + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + gets a model state_dict, returns a state_dict-like mask_dict with masks + """ + mask_dict = {} + for key, val in state_dict.items(): + orig_size = val.shape + scores = val.flatten() ** 2 + mask = self.create_mask(scores) + mask = mask.reshape(orig_size) + mask_dict[key] = mask + return mask_dict + + def create_mask(self, score, value=0): + score = score # .cpu() + orig_size = score.shape + score = score.view(-1, 4) + mask = torch.zeros(score.shape) + values, indices = torch.topk(score, 2, dim=1) + rows = torch.arange(mask.size(0)).unsqueeze(-1) + mask[rows, indices] = 1 + mask = mask.view(orig_size) + return mask # dev = score.device, return mask.to(dev) + + @staticmethod + def filter_function(name, modules_to_sparsify_in_block): + if modules_to_sparsify_in_block is None: + return False + return name in modules_to_sparsify_in_block diff --git a/modelopt/torch/_compress/tools/robust_json.py b/modelopt/torch/_compress/tools/robust_json.py new file mode 100644 index 0000000000..dbb561b828 --- /dev/null +++ b/modelopt/torch/_compress/tools/robust_json.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Provides a robust JSON encoder that can handle various types of objects, +including dataclasses, paths, enums, namespaces, and functions. +""" + +import argparse +import dataclasses +import datetime +import inspect +import json +from enum import Enum +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + + +class RobustJSONEncoder(json.JSONEncoder): + def default(self, o): + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, Path): + return str(o) + if isinstance(o, Enum): + return o.name + if isinstance(o, argparse.Namespace): + return vars(o) + if type(o).__name__ == "dtype": + return str(o) + if isinstance(o, (DictConfig, ListConfig)): + return OmegaConf.to_container(o, resolve=True) + if inspect.isfunction(o) or inspect.ismethod(o): + if o.__module__ == "__main__": + # User-defined function in main — fallback to just the name + return o.__name__ + return f"{o.__module__}.{o.__qualname__}" + if isinstance(o, datetime.timedelta): + return str(o) + return super().default(o) + + +def json_dumps(obj: Any) -> str: + return json.dumps(obj, cls=RobustJSONEncoder, indent=2) + + +def json_dump(obj: Any, path: Path | str) -> None: + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + json_text = json_dumps(obj) + path.write_text(json_text) + + +def json_load(path: Path | str) -> dict: + path = Path(path) + text = path.read_text() + return json.loads(text) diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py new file mode 100644 index 0000000000..91fcb5ebd5 --- /dev/null +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -0,0 +1,422 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Provides utilities for distributed loading, saving, and manipulation of +large language model checkpoints across multiple GPUs/processes. +""" + +import json +from collections.abc import Iterable, Mapping +from pathlib import Path +from typing import Literal, cast + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors import safe_open +from safetensors.torch import load_file as safe_load_file +from safetensors.torch import save_file as safe_save_file +from tqdm import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils.hub import cached_file, get_checkpoint_shard_files +from typing_extensions import override +from utils.utils import EmptyInitOnDevice + +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMDecoderLayer, + DeciLMForCausalLM, + rope_type_to_class, +) +from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import IRuntime + + +class DummyModule(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) + + @staticmethod + def load_state_dict_post_hook( + module: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys + ) -> None: + incompatible_keys.missing_keys.clear() + incompatible_keys.unexpected_keys.clear() + + +class DummyBlock(DummyModule): + def __init__(self, config: DeciLMConfig, block_index: int): + super().__init__() + self.config = config + self.block_index = block_index + + @override + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torch.Tensor, None]: + if self.config.block_return_only_hidden_states: + return x + else: + return x, None + + +class DummyWTE(DummyModule): + def __init__(self, config: DeciLMConfig, dtype: torch.dtype | None = None): + super().__init__() + self.n_embd = config.get_hidden_size() + self.dtype = dtype + + @override + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + B, T = input_ids.shape # noqa: N806 + result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device) + return result + + +class DummyLMHead(DummyModule): + def __init__(self, config: DeciLMConfig): + super().__init__() + self.vocab_size = config.vocab_size + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.shape # noqa: N806 + result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device) + return result + + +def create_local_shard_(model: DeciLMForCausalLM, owned_block_indexes: set[int]): + all_block_indexes = set(range(len(model.model.layers))) + has_first_block = 0 in owned_block_indexes + has_last_block = max(all_block_indexes) in owned_block_indexes + + unowned_block_indexes = all_block_indexes - owned_block_indexes + for block_index in unowned_block_indexes: + model.model.layers[block_index] = cast( + "DeciLMDecoderLayer", DummyBlock(model.config, block_index) + ) + + if not has_first_block: + model.set_input_embeddings(DummyWTE(model.config)) + + if not has_last_block: + model.model.set_final_layer_norm(nn.Identity()) + if not (model.config.tie_word_embeddings and has_first_block): + model.set_output_embeddings(DummyLMHead(model.config)) + + return model + + +def create_dummy_model( + model_config: DeciLMConfig, + dtype: torch.dtype, +) -> DeciLMForCausalLM: + with torch.device("meta"): + model = DeciLMForCausalLM(model_config) + + rope_cls = rope_type_to_class[model_config.position_embedding_type] + model.model.rotary_emb = rope_cls(config=model.config) + + model.model.set_input_embeddings(DummyWTE(model.config, dtype)) + model.model.set_final_layer_norm(nn.Identity()) + model.set_output_embeddings(DummyLMHead(model.config)) + + for block_index in range(model_config.get_num_hidden_layers()): + model.model.layers[block_index] = DummyBlock(model.config, block_index) + + return model + + +def load_and_shard_model( + runtime: IRuntime, + checkpoint_path: str | Path, + owned_block_indexes: set[int] | Literal["auto"] = "auto", + model_config: DeciLMConfig | None = None, + model_config_overrides: Mapping | None = None, +) -> DeciLMForCausalLM: + checkpoint_path = Path(checkpoint_path) + with runtime.device: + if model_config is None: + model_config = load_model_config( + checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True + ) + + if owned_block_indexes == "auto": + owned_block_indexes = set( + np.array_split(np.arange(model_config.get_num_hidden_layers()), runtime.world_size)[ + runtime.global_rank + ] + ) + + mprint("Initializing model shards") + model_shard = create_sharded_model( + runtime=runtime, + model_config=model_config, + owned_block_indexes=owned_block_indexes, + ) + + if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( + checkpoint_path / SAFE_WEIGHTS_INDEX_NAME + ).exists(): + mprint("Loading shard state_dict from safetensors") + shard_keys = [ + *[name for name, _ in model_shard.named_parameters()], + *[name for name, _ in model_shard.named_buffers()], + ] + shard_state_dict = load_sharded_state_dict( + model_name_or_path=str(checkpoint_path), + keys_to_load=shard_keys, + device=runtime.device, + ) + + new_names = set(shard_state_dict.keys()) + mprint(f"{new_names=}") + model_shard.load_state_dict(shard_state_dict, assign=True) + + del shard_state_dict + + if model_config.tie_word_embeddings and (0 in owned_block_indexes): + # re-tie the weights in case the connection was severed + model_shard.tie_weights() + else: + mprint("Loading state_dict in main process") + state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None + + mprint("Distributing model to shards") + load_state_dict_to_shards( + runtime=runtime, model_shard=model_shard, loaded_state_dict=state_dict + ) + del state_dict + + model_shard.type(runtime.dtype) + + params_on_meta_device = [ + param_name + for param_name, param in model_shard.named_parameters() + if param.device == torch.device("meta") + ] + assert len(params_on_meta_device) == 0, ( + f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" + ) + + return model_shard + + +def create_sharded_model( + runtime: IRuntime, + model_config: DeciLMConfig, + owned_block_indexes: set[int], + device: str | torch.device | None = "meta", + dtype: torch.dtype | None = torch.float32, +): + if isinstance(device, str): + device = torch.device(device) + + runtime.wait_for_everyone() + + with EmptyInitOnDevice(device="meta", dtype=dtype): + model = DeciLMForCausalLM(model_config) + create_local_shard_(model=model, owned_block_indexes=owned_block_indexes) + + if device != torch.device("meta"): + local_shard_state_dict = { + k: torch.empty_like(v, device=device) for k, v in model.state_dict().items() + } + + model.load_state_dict(local_shard_state_dict, assign=True) + + return model + + +def load_state_dict_to_shards( + runtime: IRuntime, model_shard: torch.nn.Module, loaded_state_dict: dict | None = None +) -> None: + from sewing_kit.utils import distributed_isend_obj, distributed_recv_obj + + model_shard.to("meta") + local_state_dict_keys = list(model_shard.state_dict().keys()) + + if runtime.is_main_process: + gathered_state_dict_keys = [None] * runtime.world_size + torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys) + + assert loaded_state_dict is not None + loaded_state_dict = {k.replace("_orig_mod.", ""): v for k, v in loaded_state_dict.items()} + + works: list[torch.distributed.Work] = [] + for i, shard_keys in enumerate(gathered_state_dict_keys[1:]): + process_id = i + 1 + shard_state_dict = {k: v for k, v in loaded_state_dict.items() if k in shard_keys} + process_works = distributed_isend_obj(shard_state_dict, process_id) + works.extend(process_works) + + for work in works: + work.wait() + + shard_state_dict = { + k: v for k, v in loaded_state_dict.items() if k in local_state_dict_keys + } + else: + torch.distributed.gather_object(local_state_dict_keys) + shard_state_dict = distributed_recv_obj() + + print(f"{runtime.global_rank=} loaded state_dict shard") + + missing_keys, unexpected_keys = model_shard.load_state_dict( + shard_state_dict, strict=False, assign=True + ) + assert len(unexpected_keys) == 0 + assert all("dummy_param" in key for key in missing_keys) + + model_shard.to(runtime.device) + + runtime.wait_for_everyone() + + +def save_sharded_model( + runtime: IRuntime, + model_shard: torch.nn.Module | dict[str, torch.Tensor], + out_path: str | Path, +): + """ + out_path is usually output_checkpoint_path / "model.safetensors" + """ + runtime.wait_for_everyone() + + if isinstance(model_shard, torch.nn.Module): + shard_state_dict = model_shard.state_dict() + elif isinstance(model_shard, dict): + shard_state_dict = model_shard + else: + raise ValueError(f"Unrecognized model shard type: {type(model_shard)}") + + shard_state_dict = {k: v.cpu() for k, v in shard_state_dict.items()} + total_shard_size = sum( + weight.numel() * weight.element_size() for weight in shard_state_dict.values() + ) + + num_shards = runtime.world_size + idx = runtime.global_rank + + out_path = Path(out_path) + shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}") + + shard_metadata = { + "total_shard_size": total_shard_size, + "shard_keys": list(shard_state_dict.keys()), + "shard_file": str(shard_file), + } + + if runtime.is_main_process: + shard_metadatas = [{} for _ in range(runtime.world_size)] + torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0) + total_size = sum(x["total_shard_size"] for x in shard_metadatas) + metadata = {"total_size": total_size} + weight_map: dict[str, str] = {} + for shard_metadata in shard_metadatas: + weight_map.update( + {k: Path(shard_metadata["shard_file"]).name for k in shard_metadata["shard_keys"]} + ) + + index = {"metadata": metadata, "weight_map": weight_map} + index_path = Path(str(out_path) + ".index.json") + index_path.write_text(json.dumps(index, indent=2)) + + else: + torch.distributed.gather_object(shard_metadata, dst=0) + + if out_path.suffix == ".safetensors": + safe_save_file(shard_state_dict, shard_file, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, shard_file) + + runtime.wait_for_everyone() + + +def save_sharded_state_dict( + state_dict: dict[str, torch.Tensor], + save_directory: str | Path, + max_shard_size: str = "10GB", +) -> None: + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + state_dict = {k: v.cpu() for k, v in state_dict.items()} + + state_dict_split = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size) + + for shard_filename, param_names in tqdm( + state_dict_split.filename_to_tensors.items(), desc="saving sharded state dict" + ): + shard_path = save_directory / shard_filename + shard = {param_name: state_dict[param_name] for param_name in param_names} + safe_save_file(shard, shard_path, metadata={"format": "pt"}) + + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + index_path = save_directory / SAFE_WEIGHTS_INDEX_NAME + index_path.write_text(json.dumps(index, indent=2)) + + +def load_sharded_state_dict( + model_name_or_path: str | Path, + keys_to_load: Iterable[str] | None = None, + device: torch.device | str = "cpu", +) -> dict[str, torch.Tensor]: + """ + keys_to_load: entire state_dict if None, else partial state_dict containing only these keys + """ + shard_paths = _resolve_shard_paths(model_name_or_path) + # print(f"shard_paths: {shard_paths}") + partial_state_dict = {} + for safetensors_path in shard_paths: + if keys_to_load is None: + shard = safe_load_file(safetensors_path) + partial_state_dict.update(shard) + else: + with safe_open(safetensors_path, framework="pt", device=str(device)) as f: + for key in f: + if key in keys_to_load: + partial_state_dict[key] = f.get_tensor(key) + return partial_state_dict + + +def _resolve_shard_paths(model_name_or_path: str) -> list[str]: + try: + unsharded_path = cached_file(model_name_or_path, SAFE_WEIGHTS_NAME) + return [unsharded_path] + except OSError: + index_path = cached_file(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + shard_paths, _ = get_checkpoint_shard_files(model_name_or_path, index_path) + return shard_paths + + +def is_in_safetensors_format(checkpoint_dir: Path) -> bool: + return len(list(checkpoint_dir.glob("*.safetensors"))) > 0 + + +def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: + shard_paths = _resolve_shard_paths(model_name_or_path) + state_dict_shapes = {} + for safetensors_path in shard_paths: + with safe_open(safetensors_path, framework="pt") as f: + for key in f: + state_dict_shapes[key] = tuple(f.get_tensor(key).shape) + return state_dict_shapes diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/experimental/torch/_compress/compress_test_utils.py index f0704f6c89..1600989225 100644 --- a/tests/experimental/torch/_compress/compress_test_utils.py +++ b/tests/experimental/torch/_compress/compress_test_utils.py @@ -19,9 +19,10 @@ import torch from datasets import Dataset, DatasetDict -from puzzle_tools.hydra_utils import register_hydra_resolvers from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers + def setup_test_model_and_data( project_root_path: Path, diff --git a/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py similarity index 100% rename from tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py rename to tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index a38322d141..cd4e34ca1d 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + import pytest import torch import torch.distributed as dist @@ -57,3 +59,9 @@ def set_torch_dtype(request): @pytest.fixture(scope="session", autouse=True) def enable_hf_checkpointing(): mto.enable_huggingface_checkpointing() + + +@pytest.fixture +def project_root_path(request: pytest.FixtureRequest) -> Path: + """Fixture providing the project root path for tests.""" + return Path(request.config.rootpath) From 0868f1caf6f1e02dea5b625b06f5743320a4649b Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 14 Nov 2025 20:40:30 +0100 Subject: [PATCH 11/62] Add integration test for attention pruning (#562) ## What does this PR do? Make score pruning activations self-contained (no dependency on internal Nvidia code) - 1/6 - add integration test for attention pruning (pre-step to make adding activation scoring safer) --------- Signed-off-by: Daniel Korzekwa --- .../torch/_compress/compress_test_utils.py | 10 +- .../_compress/nas/plugins/test_nas_convert.py | 81 ++++++++++++- .../_compress/nas/plugins/test_nas_search.py | 8 +- .../configs/Llama-3_1-8B-attn-pruning.yaml | 108 ++++++++++++++++++ ...-8B.yaml => Llama-3_1-8B-ffn-pruning.yaml} | 0 .../torch/_compress/test_compress.py | 8 +- 6 files changed, 197 insertions(+), 18 deletions(-) create mode 100644 tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml rename tests/experimental/torch/_compress/resources/configs/{Llama-3_1-8B.yaml => Llama-3_1-8B-ffn-pruning.yaml} (100%) diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/experimental/torch/_compress/compress_test_utils.py index 1600989225..ce22e1864c 100644 --- a/tests/experimental/torch/_compress/compress_test_utils.py +++ b/tests/experimental/torch/_compress/compress_test_utils.py @@ -33,8 +33,6 @@ def setup_test_model_and_data( Path, Path, Path, - Path, - str, ]: """ Setup the test model and data for the compress NAS search. @@ -46,8 +44,8 @@ def setup_test_model_and_data( runtime: the runtime to use for the test Returns: - tuple[Path, Path, Path, Path, str]: - the puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name + tuple[Path, Path, Path]: + the puzzle_dir, llama_checkpoint_path, dataset_path """ # Register Hydra custom resolvers (needed for config resolution) @@ -58,8 +56,6 @@ def setup_test_model_and_data( puzzle_dir = tmp_path llama_checkpoint_path = puzzle_dir / "input_model/llama" dataset_path = puzzle_dir / "dummy_dataset" - hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B" if rank == 0: # Setup puzzle_dir and dataset @@ -77,8 +73,6 @@ def setup_test_model_and_data( puzzle_dir, llama_checkpoint_path, dataset_path, - hydra_config_dir, - hydra_config_name, ) diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py index 47ff2531da..cf284cfc87 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py @@ -31,24 +31,28 @@ # See tests/experimental/torch/_compress/test_compress.py for instructions on how to run this test # TODO: Remove those instructions once this test runs automatically on CI # -def test_nas_convert(project_root_path: Path, tmp_path: Path): +def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), - job=partial(_test_nas_convert_multiprocess_job, project_root_path, tmp_path), + job=partial(_test_nas_convert_ffn_pruning_multiprocess_job, project_root_path, tmp_path), backend="nccl", ) -def _test_nas_convert_multiprocess_job( +def _test_nas_convert_ffn_pruning_multiprocess_job( project_root_path: Path, tmp_path: Path, rank: int, size: int ): with NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( - setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-ffn-pruning" # # Run the mnt.convert() step @@ -86,4 +90,69 @@ def _test_nas_convert_multiprocess_job( runtime.wait_for_everyone() - print("PYTEST SUMMARY: test_nas_convert() test has finished successfully") + print("PYTEST SUMMARY: test_nas_convert_ffn_pruning() test has finished successfully") + + +def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_convert_attn_pruning_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_convert_attn_pruning_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + with NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) as runtime: + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime + ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-attn-pruning" + + # + # Run the mnt.convert() step + # + input_model = CompressModel() + mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/attn_independent_kv_head_contribution/" + f"100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/n_heads_in_group8").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group16").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group32").exists() + + runtime.wait_for_everyone() + + print("PYTEST SUMMARY: test_nas_convert_attn_pruning() test has finished successfully") diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py index df3c1e4856..4a6a3eccec 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py @@ -45,9 +45,13 @@ def _test_nas_search_multiprocess_job( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( - setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-ffn-pruning" # # Run the mnt.convert() step diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml new file mode 100644 index 0000000000..21a3486f09 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml @@ -0,0 +1,108 @@ +defaults: + - pruning: attn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml rename to tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 96af36b5e9..76407bc1f0 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -61,9 +61,13 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( - setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-ffn-pruning" # Convert the Llama model to DeciLM model. if rank == 0: From 1dde209167dbf677269f429c8c7f144eb6b8a3a8 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 18 Nov 2025 19:31:10 +0100 Subject: [PATCH 12/62] Add score_pruning_activations (step 2/6) (#563) ## What does this PR do? - Add score_pruning_activations.py Notes: - validate_model.py still depends on Nvidia internal code (will be changed in the subsequent MR) - sharded_checkpoint_utils.py - for now it needs to use DeciLM from internal Nvidia code, to be changed in the next MR --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Daniel Korzekwa Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../score_pruning_activations.py | 174 ++++++++++ .../nas/plugins/compress_nas_plugin.py | 2 +- .../tools/sharded_checkpoint_utils.py | 8 +- .../torch/_compress/tools/validate_model.py | 297 ++++++++++++++++++ modelopt/torch/_compress/utils/dist_utils.py | 30 ++ modelopt/torch/_compress/utils/utils.py | 62 ++++ 6 files changed, 568 insertions(+), 5 deletions(-) create mode 100644 modelopt/torch/_compress/activation_scoring/score_pruning_activations.py create mode 100644 modelopt/torch/_compress/tools/validate_model.py create mode 100644 modelopt/torch/_compress/utils/dist_utils.py create mode 100644 modelopt/torch/_compress/utils/utils.py diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py new file mode 100644 index 0000000000..3617bdb1c2 --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import hydra +import torch +from omegaconf import DictConfig +from utils.parsing import format_global_config + +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime +from modelopt.torch._compress.tools.validate_model import validate_model +from modelopt.torch._compress.utils.dist_utils import is_distributed + + +def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: + """ + Determine if the activation hook method has proper checkpoint support implemented. + + Args: + activation_hooks_kwargs: Hook configuration + + Returns: + bool: True if the hook method has save_state/load_state implemented + """ + method = activation_hooks_kwargs.get("method", "") + + # Methods with implemented checkpoint support + supported_methods = { + "iterative", # IterativeChannelContributionHook: save_state/load_state implemented + "independent", # IndependentChannelContributionHook: save_state/load_state implemented + "stats", # RouterStatsHook: save_state/load_state implemented + "ranked_choice_voting", # RankedChoiceVotingHook: save_state/load_state implemented + } + + return method in supported_methods + + +def check_scoring_completion( + activations_log_dir: str, runtime, activation_hooks_kwargs=None +) -> bool: + """ + Check if scoring is already completed by looking for the expected output files. + Also checks if the scoring method is safe for resume. + + Args: + activations_log_dir: Directory where activation logs should be stored + runtime: Runtime object for distributed processing + activation_hooks_kwargs: Hook configuration to check if resume is safe + + Returns: + bool: True if scoring is completed (has rank files and args.json) + """ + # Only check completion on main process (or if no distributed runtime) + if runtime is None or runtime.is_main_process: + log_dir = Path(activations_log_dir) + + # Check if directory exists + if not log_dir.exists(): + return False + + # Check for rank files (at least rank_0.pth should exist) + rank_files = list(log_dir.glob("rank_*.pth")) + + if not rank_files: + return False + + # Check for args.json (created by main process) + args_file = log_dir / "args.json" + has_args_json = args_file.exists() + + # Check for completion: if we have rank files and args.json, scoring is complete + if rank_files and has_args_json: + # Add optional completion info for debugging + mprint(f"Found completed scoring in {activations_log_dir}") + mprint(f" - Found {len(rank_files)} rank files") + mprint(f" - Found args.json: {has_args_json}") + + return True + + return False + + +def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: + """ + Determine if we should skip scoring entirely (only if 100% complete). + Partial progress should proceed to validate_model for proper resume. + + Args: + cfg: Configuration object + runtime: Runtime object for distributed processing + + Returns: + bool: True if we should skip scoring (100% completed), False if we should run/resume it + """ + # Check if activations_log_dir is specified + if not hasattr(cfg.pruning, "activations_log_dir") or cfg.pruning.activations_log_dir is None: + mprint("No activations_log_dir specified, running scoring") + return False + + # Check for force restart flag + force_restart = getattr(cfg.pruning, "force_restart_scoring", False) + if force_restart: + mprint("Force restart flag set, will restart scoring regardless of existing artifacts") + return False + + # Get hook configuration to check if resume is mathematically safe + activation_hooks_kwargs = getattr(cfg.pruning, "activation_hooks_kwargs", {}) + + # Check if scoring is already completed + is_completed = check_scoring_completion( + cfg.pruning.activations_log_dir, runtime, activation_hooks_kwargs + ) + + # Broadcast the result to all processes in distributed mode + if runtime is not None and runtime.world_size > 1: + should_skip = [is_completed] # Use list for mutable object + torch.distributed.broadcast_object_list(should_skip, src=0) + is_completed = should_skip[0] + + if is_completed: + mprint("Scoring 100% completed, skipping...") + + return is_completed + + +# Old progress tracking removed - checkpoint manager handles all progress tracking + + +def launch_score_activations(cfg: DictConfig, runtime): + # Check if we should skip scoring entirely (only if 100% complete) + if should_skip_scoring_completely(cfg, runtime): + return + + mprint("Starting pruning activation scoring...") + + # The checkpoint manager inside validate_model handles all progress tracking + validate_model(args=cfg.pruning, runtime=runtime) + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + mprint(format_global_config(cfg, title="Score Pruning Activations")) + + _runtime = ( + NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") + ) + if is_distributed() + else BaseRuntime(dtype=torch.bfloat16) + ) + with _runtime as runtime: + launch_score_activations(cfg, runtime) + runtime.wait_for_everyone() + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 13d418b69d..84af06b137 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -26,11 +26,11 @@ import build_library_and_stats import mip_and_realize_models import pruning_ckpts -import score_pruning_activations import scoring import torch from torch import nn +from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 91fcb5ebd5..a27cd50771 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -29,6 +29,7 @@ import torch.distributed import torch.nn as nn from huggingface_hub import split_torch_state_dict_into_shards +from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file @@ -36,17 +37,16 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override -from utils.utils import EmptyInitOnDevice from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, - DeciLMForCausalLM, rope_type_to_class, ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.runtime import IRuntime +from modelopt.torch._compress.utils.utils import EmptyInitOnDevice class DummyModule(nn.Module): @@ -392,7 +392,7 @@ def load_sharded_state_dict( partial_state_dict.update(shard) else: with safe_open(safetensors_path, framework="pt", device=str(device)) as f: - for key in f: + for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable if key in keys_to_load: partial_state_dict[key] = f.get_tensor(key) return partial_state_dict @@ -417,6 +417,6 @@ def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: state_dict_shapes = {} for safetensors_path in shard_paths: with safe_open(safetensors_path, framework="pt") as f: - for key in f: + for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable state_dict_shapes[key] = tuple(f.get_tensor(key).shape) return state_dict_shapes diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py new file mode 100644 index 0000000000..e264ea6813 --- /dev/null +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import textwrap +from pathlib import Path + +import torch.distributed +from omegaconf import DictConfig +from torch import nn +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from utils.activation_hooks.utils import register_activation_hooks +from utils.data.dataloaders import create_validation_dataloader +from utils.parsing import simple_parse_args_string +from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline +from utils.validation import calculate_losses + +from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint +from modelopt.torch._compress.tools.logger import aprint, mprint +from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime +from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model + +# #TODO:Import slack from root utils directory +# root_path = os.path.join(os.path.dirname(__file__), "..", "..") +# if root_path not in sys.path: +# sys.path.append(root_path) +# from utils.slack import send_slack_message + +""" +Two goals: +1) Calculate lm loss and token accuracy for a model. +May raise lots of NCCL warnings when it finishes, don't be alarmed. +Can be used to validate a HuggingFace model. +Automatically uses pipeline parallelism via device_map="auto". + +2) Register hooks to capture the inputs and the outputs of pytorch modules. +For example, to collect activations scores for various layers (ffn, layer_norm, etc.) +that are used for pruning (ffn_hidden_size, embedding_pruning, etc). +See --activations_log_dir and --activation_hooks_kwargs args arguments. + +""" + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", + type=str, + default=None, + help="Required unless a model is passed to the function", + ) + parser.add_argument("--dataset_path", type=str, required=True) + + parser.add_argument("--output_dir_name", type=str, default="validation") + parser.add_argument( + "--calculate_full_score_ablations", + action="store_true", + help="Calculates a diverse suite of teacher similarity scores. " + "By default only a small suite is calculated, which is good for most use-cases.", + ) + + parser.add_argument("--tokenizer_name", type=str, default=None) + parser.add_argument("--data_column", type=str, default="content") + # TODO: Add help text for FIM rate, also for others less obvious args + parser.add_argument("--fim_rate", type=float, default=0) + parser.add_argument("--fim_spm_rate", type=float, default=0) + parser.add_argument("--eval_samples", type=int, default=None) + parser.add_argument("--block_size", type=int, default=4096) + parser.add_argument("--micro_batch_size", type=int, default=4) + parser.add_argument("--val_dataset_name", type=str, default="__auto__") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--source_datasets_to_discard", nargs="+", type=str) + parser.add_argument("--bos_rate", type=float, default=1.0) + parser.add_argument("--shuffle_seed", type=int, default=None) + parser.add_argument("--varlen", action="store_true") + parser.add_argument("--pipeline_parallel", action="store_true") + parser.add_argument("--write_results", action="store_true") + parser.add_argument("--activations_log_dir", type=str, default=None) + parser.add_argument( + "--activation_hooks_kwargs", + type=str, + default=None, + help="Comma separated string arguments, e.g. `arg1=val1,arg2=val2`", + ) + parser.add_argument( + "--calc_losses_on_cpu", + action="store_true", + help="Very slow, not recommended. Can help avoid OOM.", + ) + return parser + + +def parse_args() -> argparse.Namespace: + parser = build_arg_parser() + args, unknown_args = parser.parse_known_args() + return args + + +@torch.no_grad() +def validate_model( + args: argparse.Namespace | DictConfig, + model: PreTrainedModel | None = None, + tokenizer: PreTrainedTokenizerBase | None = None, + target_hidden_states_per_batch: list[torch.Tensor] | None = None, + return_hidden_states: bool = False, + runtime: IRuntime | None = None, + calculate_full_score_ablations: bool = False, + val_dataloader: DataLoader | None = None, +) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + if val_dataloader is None: + val_dataloader = ( + prepare_dataloader(args, tokenizer) + if (runtime is None or runtime.is_main_process) + else None + ) + validation_full_iters = ( + args.eval_samples // args.micro_batch_size + ) # model pipeline, single data rank + + model = prepare_model(args, model, runtime) + + just_model_forward = False + checkpoint_manager = None + activation_hooks = None + + if args.activations_log_dir is not None: + activation_hooks_kwargs = ( + simple_parse_args_string(args.activation_hooks_kwargs) + if isinstance(args.activation_hooks_kwargs, str) + else args.activation_hooks_kwargs + ) + activation_hooks_kwargs["validation_full_iters"] = validation_full_iters + + # Create activation hooks first + activation_hooks, hook_class = register_activation_hooks( + model=model, activation_hooks_kwargs=activation_hooks_kwargs + ) + + # Create checkpoint manager with hooks + from utils.checkpoint_manager import ScoringCheckpointManager + + mprint( + f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}" + ) + checkpoint_manager = ScoringCheckpointManager( + checkpoint_dir=args.activations_log_dir, + runtime=runtime, + activation_hooks=activation_hooks, + checkpoint_interval=50, # Save every 50 batches + ) + + # Load existing checkpoint if available + mprint("Attempting to load existing checkpoint...") + checkpoint_data = checkpoint_manager.load_checkpoint() + if checkpoint_data: + mprint(f"Checkpoint loaded successfully: {checkpoint_data}") + else: + mprint("No checkpoint found, starting fresh") + just_model_forward = True + model.lm_head = nn.Identity() + + if runtime is None: + losses, hidden_states_per_batch = calculate_losses( + model=model, + dataloader=val_dataloader, + checkpoint_manager=checkpoint_manager, + ) + else: + losses, hidden_states_per_batch = calculate_losses_pipeline( + runtime=runtime, + stitched_model=model, + dataloader=val_dataloader, + target_hidden_states_per_batch=target_hidden_states_per_batch, + return_hidden_states=return_hidden_states, + calculate_full_score_ablations=calculate_full_score_ablations, + calc_on_cpu=args.calc_losses_on_cpu, + just_model_forward=just_model_forward, + checkpoint_manager=checkpoint_manager, + ) + + if losses is not None: + avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()} + + results_str = f""" + validate_model: + {args.model_name_or_path=} + Average losses = {avg_losses} + Actual num samples = {len(next(iter(losses.values()))["per_sample"])} + {args=} + """ + results_str = textwrap.dedent(results_str) + aprint(results_str) + if args.write_results: + Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str) + # TODO: send_slack_message(results_str) + + if args.activations_log_dir is not None: + hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args, runtime) + + return losses, hidden_states_per_batch + + +def prepare_model( + args: argparse.Namespace, + model: PreTrainedModel | None = None, + runtime: IRuntime | None = None, +) -> nn.Module: + if model is None: + assert args.model_name_or_path is not None + if runtime is not None: + model = load_and_shard_model( + runtime, + args.model_name_or_path, + model_config_overrides={"block_size": args.block_size}, + ) + else: + try: + model = load_checkpoint( + args.model_name_or_path, + model_config_overrides={"block_size": args.block_size}, + ignore_unexpected_config_keys=True, + ) + model.to("cuda") + except FileNotFoundError: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + torch_dtype="auto", + device_map="auto", + trust_remote_code=True, + ) + + model.eval() + return model + + +def prepare_dataloader( + args: argparse.Namespace, + tokenizer: PreTrainedTokenizerBase | None = None, +) -> DataLoader: + if tokenizer is None: + tokenizer_name = getattr(args, "tokenizer_name", None) + assert (tokenizer_name is not None) or (args.model_name_or_path is not None) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name or args.model_name_or_path, trust_remote_code=True + ) + + val_dataloader = create_validation_dataloader( + accelerator=None, + seed=args.seed, + tokenizer=tokenizer, + block_size=args.block_size, + dataset=args.dataset_path, + content_field=args.data_column, + fim_rate=args.fim_rate, + fim_spm_rate=args.fim_spm_rate, + micro_batch_size=args.micro_batch_size, + eval_samples=args.eval_samples, + dataset_name=args.val_dataset_name, + source_datasets_to_discard=args.source_datasets_to_discard, + bos_rate=args.bos_rate, + varlen=args.varlen, + shuffle_seed=args.shuffle_seed, + load_dataset_fn=args.load_dataset_fn, + ) + + return val_dataloader + + +def main(): + args = parse_args() + if args.pipeline_parallel: + with NativeDdpRuntime(dtype=torch.bfloat16) as runtime: + validate_model(args=args, runtime=runtime) + else: + validate_model(args=args, runtime=None) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/_compress/utils/dist_utils.py b/modelopt/torch/_compress/utils/dist_utils.py new file mode 100644 index 0000000000..84f8f2bab1 --- /dev/null +++ b/modelopt/torch/_compress/utils/dist_utils.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch.distributed as dist + + +def is_distributed(): + """ + From torchtune.utils.is_distributed() : https://docs.pytorch.org/torchtune/0.2/generated/torchtune.utils.is_distributed.html + """ + port = os.environ.get("MASTER_PORT", "") + addr = os.environ.get("MASTER_ADDR", "") + size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", -1)) + avlb = dist.is_available() + return bool(port and addr and size > 1 and rank >= 0 and avlb) diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py new file mode 100644 index 0000000000..ef952dfec6 --- /dev/null +++ b/modelopt/torch/_compress/utils/utils.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): + def __init__(self, device=None, dtype=None): + """ + Create tensors with given device and dtype and don't run initialization + (but instead use "empty tensors", i.e. uninitialized memory). + + device: `torch.device` to work with + dtype: `torch.dtype` to work with + + + Example:: + with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): + model = LLaMA(model_config) + model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))""" + + self.device = device + self.dtype = dtype + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if getattr(func, "__module__", None) == "torch.nn.init": + if "tensor" in kwargs: + return kwargs["tensor"] + else: + return args[0] + if ( + self.device is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("device") is None + ): + kwargs["device"] = self.device + if ( + self.dtype is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("dtype") is None + ): + kwargs["dtype"] = self.dtype + return func(*args, **kwargs) From 2e559e7fbf8ca3a1e90f24ab1b17819e4f4ee5ce Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:43:58 +0530 Subject: [PATCH 13/62] Update README.md Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/compress/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/compress/README.md b/examples/compress/README.md index a4881150d0..d4c749f473 100644 --- a/examples/compress/README.md +++ b/examples/compress/README.md @@ -1,6 +1,5 @@ # Compress Algorithm Tutorial -This tutorial demonstrates how to compress large language models using the Compress algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). This tutorial demonstrates how to compress large language models using the compress algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). The goal of the algorithm it to find the most optimal modifications to MLP and attention layers of the model, resulting in a heterogeneous model architecture. The supported modifications are: From f10be0d4f98c8f1764d1e3971785fcf616f12697 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 14:08:24 +0100 Subject: [PATCH 14/62] Add activation hooks used for pruning (#576) ## What does this PR do? Add activation hooks used for pruning --------- Signed-off-by: Daniel Korzekwa --- examples/compress/README.md | 2 +- .../activation_hooks/__init__.py | 15 + .../activation_hooks/hooks.py | 562 ++++++++++++++++++ .../activation_hooks/utils.py | 91 +++ .../torch/_compress/tools/validate_model.py | 11 +- modelopt/torch/_compress/utils/utils.py | 1 - 6 files changed, 678 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py create mode 100644 modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py create mode 100644 modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py diff --git a/examples/compress/README.md b/examples/compress/README.md index d4c749f473..3bd218aa48 100644 --- a/examples/compress/README.md +++ b/examples/compress/README.md @@ -29,7 +29,7 @@ pip install -e .[hf,compress] How to choose `intermediate_size_list`? The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g. 15%, 20%, 30% etc of the original). Note that the values must be hardware-friendly (divisible by a 256) to avoid issues with tensor operations in subsequent steps. - Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` GiB. This means that the algorithm will choose the candidates with highest accuracy that also meet the specified requirements. + Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` MiB. This means that the algorithm will choose the candidates with highest accuracy that also meet the specified requirements. 2. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py new file mode 100644 index 0000000000..6339d55ab6 --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py @@ -0,0 +1,562 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Provides hooks for capturing the inputs and the outputs of pytorch modules that are used for +activation scoring for pruning. +""" + +import argparse +import gc +import json +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path + +import torch +import torch.nn.functional as F +from omegaconf import DictConfig, OmegaConf +from torch import nn + +# BlockConfig used at runtime, not just type hints (lines 680, 790) +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import ( + DeciLMConfig, # noqa: TC001 +) +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMRMSNorm +from modelopt.torch._compress.tools.logger import aprint +from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch._compress.tools.runtime import IRuntime + + +def clear_gpu_memory(clear: bool) -> None: + if clear: + gc.collect() + torch.cuda.empty_cache() + + +class ActivationsHook(ABC): + @abstractmethod + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + A hook to be registered in pytorch modules: torch.nn.Module.register_forward_hook() + + Args: + module (nn.Module): + args (tuple[torch.Tensor]): Input of the pytorch module + output (torch.Tensor): Output of the pytorch module + """ + ... + + @abstractmethod + def to_dict(self) -> dict[str, torch.Tensor]: ... + + def save_state(self) -> dict: + """ + Save the internal state of the hook for checkpointing. + + Returns: + dict: State dictionary that can be used to restore the hook's state + """ + # Default implementation - hooks should override this if they have state to save + return {} + + def load_state(self, state_dict: dict) -> None: + """ + Load the internal state of the hook from a checkpoint. + + Args: + state_dict: State dictionary previously returned by save_state() + """ + # Default implementation - hooks should override this if they have state to load + + def get_progress_info(self) -> dict: + """ + Get progress information for this hook (e.g., current iteration, samples processed). + + Returns: + dict: Progress information + """ + # Default implementation - hooks can override to provide progress info + return {} + + @classmethod + def dump_activations_logs( + cls: type["ActivationsHook"], + activation_hooks: dict[str, "ActivationsHook"], + activations_log_dir: Path | str, + args: argparse.Namespace, + runtime: IRuntime | None, + ): + """ + Default implementation for dumping final activation scores logs to disk. + This is called only at the end of scoring to save final results. + """ + + activations_log_dir = Path(activations_log_dir) + activations_log_dir.mkdir(exist_ok=True, parents=True) + rank = runtime.global_rank if runtime is not None else 0 + activations_log_path = activations_log_dir / f"rank_{rank}.pth" + activations_log = { + module_name: hook.to_dict() for module_name, hook in activation_hooks.items() + } + torch.save(activations_log, activations_log_path) + + if rank == 0: + args.activation_hooks_kwargs.pop("model") + json_dump( + OmegaConf.to_container(args, resolve=True) + if isinstance(args, DictConfig) + else vars(args), + activations_log_dir / "args.json", + ) + if runtime is not None: + runtime.wait_for_everyone() # rank 0 will not wait before dumping args.json + + aprint(f"Dumped final activations log to {activations_log_path}") + + @classmethod + def save_hook_states( + cls: type["ActivationsHook"], + activation_hooks: dict[str, "ActivationsHook"], + activations_log_dir: Path | str, + runtime: IRuntime | None, + ): + """ + Save hook states for checkpointing (separate from final results). + This can be called periodically during scoring. + Note: Synchronization should be handled at a higher level to avoid deadlocks. + """ + activations_log_dir = Path(activations_log_dir) + activations_log_dir.mkdir(exist_ok=True, parents=True) + rank = runtime.global_rank if runtime is not None else 0 + + hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth" + hook_states = { + module_name: hook.save_state() for module_name, hook in activation_hooks.items() + } + torch.save(hook_states, hook_states_path) + + return hook_states_path + + +class IndependentChannelContributionHook(ActivationsHook): + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + weight_matrix = linear_layer.weight.float() + self.weight_norm = torch.linalg.vector_norm(weight_matrix, dim=0) + num_channels = linear_layer.in_features + self.agg_channel_activations = torch.zeros( + size=(num_channels,), dtype=torch.float32, device=weight_matrix.device + ) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + activations = args[0] + mean_abs_channel_activations = ( + activations.abs().float().mean(dim=list(range(activations.ndim - 1))) + ) + self.agg_channel_activations[:] += mean_abs_channel_activations # shape [I] + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "score": (self.weight_norm * self.agg_channel_activations).cpu(), + "weight_norm": self.weight_norm.cpu(), + "agg_channel_activations": self.agg_channel_activations.cpu(), + } + + def save_state(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "agg_channel_activations": self.agg_channel_activations.cpu().clone(), + "weight_norm": self.weight_norm.cpu().clone(), + } + + def load_state(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.agg_channel_activations = state_dict["agg_channel_activations"].to( + self.agg_channel_activations.device + ) + # weight_norm should be the same as it's derived from the model weights + # but we can verify it matches + expected_weight_norm = state_dict["weight_norm"].to(self.weight_norm.device) + if not torch.allclose(self.weight_norm, expected_weight_norm, rtol=1e-5): + print( + "Warning: weight_norm mismatch during state loading - model weights may have changed" + ) + + +def get_pruning_schedule(num_channels, pruning_iters): + """ + Spending decreases monotonically when num_channels >= pruning_iters. + Intervals between spends increase monotonically when pruning_iters > num_channels. + The budget is fully utilized, and there's spending in the last iteration. + num_channels = 10, pruning_iters = 4 ==> [3, 3, 2, 2] + num_channels = 4, pruning_iters = 10 ==> [0, 1, 0, 1, 0, 0, 1, 0, 0, 1] + """ + if num_channels >= pruning_iters: + # Case when budget is greater than or equal to iterations + q = num_channels // pruning_iters # Base spend per iteration + r = num_channels % pruning_iters # Remainder to distribute + + schedule = [] + for i in range(pruning_iters): + if i < r: + # Assign higher spend to earlier iterations + schedule.append(q + 1) + else: + schedule.append(q) + else: + # Case when iterations are greater than budget + schedule = [0] * pruning_iters + for i in range(1, num_channels + 1): + # Distribute spends at positions where intervals increase monotonically + pos = ((i * pruning_iters) // num_channels) - 1 + schedule[pos] = 1 + return schedule + + +class IterativeChannelContributionHook(ActivationsHook): + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + """TODO: Add docstring. + + Args: + linear_layer: The linear projection layer + activation_hooks_kwargs: The activation hooks kwargs + """ + self.weight_matrix = linear_layer.weight + self.num_channels = linear_layer.in_features + self.pruning_iters = activation_hooks_kwargs["validation_full_iters"] + self.clear_gpu_memory = activation_hooks_kwargs.get("clear_gpu_memory", False) + self.curr_iter = 0 + self.pruning_schedule = get_pruning_schedule( + num_channels=self.num_channels, pruning_iters=self.pruning_iters + ) + + self.agg_cont_per_channel = torch.zeros( + size=(self.num_channels,), + dtype=torch.float32, + device=self.weight_matrix.device, + ) + self.pruned_channels = [] + self.calibration_method = activation_hooks_kwargs.get("calibration_method") + self.epsilon = 1e-8 + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + activations = args[0] + n_channels_to_prune = self.pruning_schedule[self.curr_iter] + + curr_activations = activations.clone() # Shape B,T,I + curr_activations[..., self.pruned_channels] = 0 + output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E + + if self.calibration_method is None: + scaling_factor_per_token = torch.ones_like(output[..., 0]) # Shape B,T + elif self.calibration_method == "scale_by_magnitude": + output_norms = torch.linalg.vector_norm(output, dim=-1) # Shape B,T + output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T + scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon) + del output_curr_norms, output_norms + else: + raise NotImplementedError + del curr_activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + s = scaling_factor_per_token.unsqueeze(-1) * output - output_curr # Shape: (B, T, E) + s_squared_per_token = torch.sum(s**2, dim=-1) # Shape: (B, T) + b = s @ self.weight_matrix # Shape: (B, T, I) + c = torch.sum(self.weight_matrix**2, dim=0) # Shape: (I) + del s, output_curr + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution_squared = ( + s_squared_per_token.unsqueeze(2) + 2 * activations * b + (activations**2) * c + ) # Shape: (B, T, I) + del s_squared_per_token, b, c, activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution = torch.sqrt(contribution_squared + self.epsilon) # Shape: (B, T, I) + mean_cont_per_channel = torch.mean(contribution, dim=(0, 1)) # Shape: (I) + mean_cont_per_channel[self.pruned_channels] = torch.inf + del contribution, contribution_squared + clear_gpu_memory(clear=self.clear_gpu_memory) + + if n_channels_to_prune == 0: + self.agg_cont_per_channel += mean_cont_per_channel + else: + _, worst_indices = torch.topk(mean_cont_per_channel, n_channels_to_prune, largest=False) + worst_indices_list = worst_indices.tolist() + assert not set(self.pruned_channels).intersection(set(worst_indices_list)) + self.pruned_channels.extend(worst_indices_list) + self.agg_cont_per_channel.zero_() + self.curr_iter += 1 + + def to_dict(self) -> dict[str, torch.Tensor]: + assert self.num_channels == len(self.pruned_channels) + channels_importance_ascending = torch.tensor(self.pruned_channels, dtype=torch.long) + score = torch.empty(self.num_channels, dtype=torch.long) + score[channels_importance_ascending] = torch.arange(self.num_channels, dtype=torch.long) + + return { + "score": score.cpu(), + "channels_importance_ascending": channels_importance_ascending.cpu(), + } + + def save_state(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "curr_iter": self.curr_iter, + "pruned_channels": self.pruned_channels.copy(), + "agg_cont_per_channel": self.agg_cont_per_channel.cpu().clone(), + "num_channels": self.num_channels, + "pruning_iters": self.pruning_iters, + "pruning_schedule": self.pruning_schedule.copy(), + "calibration_method": self.calibration_method, + "epsilon": self.epsilon, + } + + def load_state(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.curr_iter = state_dict["curr_iter"] + self.pruned_channels = state_dict["pruned_channels"].copy() + self.agg_cont_per_channel = state_dict["agg_cont_per_channel"].to(self.weight_matrix.device) + # Verify other parameters match + assert self.num_channels == state_dict["num_channels"], "Channel count mismatch" + assert self.pruning_iters == state_dict["pruning_iters"], "Iteration count mismatch" + assert self.pruning_schedule == state_dict["pruning_schedule"], "Pruning schedule mismatch" + + def get_progress_info(self) -> dict: + """Get progress information.""" + progress = self.curr_iter / self.pruning_iters if self.pruning_iters > 0 else 0.0 + return { + "curr_iter": self.curr_iter, + "total_iters": self.pruning_iters, + "progress": progress, + "pruned_channels_count": len(self.pruned_channels), + "total_channels": self.num_channels, + } + + +class IndependentKvHeadContributionHook(ActivationsHook): + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + """TODO: Add docstring. + + Args: + linear_layer: The linear projection layer + activation_hooks_kwargs: The activation hooks kwargs + """ + model_config: DeciLMConfig = activation_hooks_kwargs["model"].config + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + + self.optimize_for = activation_hooks_kwargs.get("optimize_for", "memory") + assert self.optimize_for in ["latency", "memory"] + + self.hidden_size = model_config.hidden_size + self.n_heads_in_group = block_config.attention.n_heads_in_group + self.num_q_heads = model_config.num_attention_heads + self.num_kv_heads = self.num_q_heads // self.n_heads_in_group + self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) + + self.agg_kv_head_contributions = torch.zeros( + size=(self.num_kv_heads,), + dtype=torch.float32, + device=linear_layer.weight.device, + ) + + # Reshape weight matrix to group by KV heads + self.weight_grouped = linear_layer.weight.view( + self.hidden_size, self.num_kv_heads, self.head_dim * self.n_heads_in_group + ).permute((1, 0, 2)) + # weight_grouped.shape: (kv_heads, hidden_dim, head_dim * n_heads_in_group) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: The linear projection layer + :param args: tuple containing attention output tensor (B, T, num_q_heads * head_dim) + :param output: The projected output (B, T, hidden_dim) + """ + attn_out = args[0] # Shape: (B, T, num_q_heads * head_dim) + batch_size, seq_len, _ = attn_out.shape + + # Reshape attention output to group by KV heads + attn_out_grouped = attn_out.view( + batch_size, + seq_len, + self.num_kv_heads, + self.head_dim * self.n_heads_in_group, + ).unsqueeze(-2) + # attn_out_grouped.shape: (B, T, kv_heads, 1, head_dim * n_heads_in_group) + + if self.optimize_for == "latency": + # Compute contribution per KV head group + # First compute the projection for each KV head group + layer_out_grouped = attn_out_grouped @ self.weight_grouped.transpose(-1, -2) + layer_out_grouped = layer_out_grouped.squeeze(-2) + # layer_out_grouped.shape: (B, T, kv_heads, hidden_dim) + + else: + layer_out_grouped = [] + for i in range(self.num_kv_heads): + _layer_out = attn_out_grouped[:, :, i] @ self.weight_grouped[i].transpose(-1, -2) + layer_out_grouped.append(_layer_out) + layer_out_grouped = torch.cat(layer_out_grouped, dim=2) + + # Compute L2 norm of each group's contribution + contrib_per_kv_head = torch.linalg.vector_norm(layer_out_grouped, dim=-1) + # contrib_per_kv_head.shape: (B, T, kv_heads) + + contrib_per_kv_head = contrib_per_kv_head.mean(dim=(0, 1)) + # contrib_per_kv_head.shape: (kv_heads,) + + # Accumulate contributions + self.agg_kv_head_contributions += contrib_per_kv_head + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "score": self.agg_kv_head_contributions.cpu(), + } + + +class LayerNormContributionHook(ActivationsHook): + def __init__(self, layernorm_layer: DeciLMRMSNorm, activation_hooks_kwargs: dict): + """Aggregates mean absolute activation values per channel for a layer normalization layer. + + Args: + layernorm_layer: The layer normalization layer + activation_hooks_kwargs: The activation hooks kwargs (not used) + """ + self.agg_embedding_activations = torch.zeros( + size=(layernorm_layer.weight.shape[0],), + dtype=torch.float32, + device=layernorm_layer.weight.device, + ) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + self.agg_embedding_activations += ( + output.abs().float().mean(dim=list(range(output.ndim - 1))) + ) + + @classmethod + def dump_activations_logs( + cls: type["LayerNormContributionHook"], + activation_hooks: dict[str, "ActivationsHook"], + activations_log_dir: Path | str, + args: argparse.Namespace, + runtime: IRuntime | None, + ): + """ + At the end of the default implementation of dumping activation scores to disc, + save aggregated channel importance results. + """ + + super().dump_activations_logs(activation_hooks, activations_log_dir, args, runtime) + + rank = runtime.global_rank if runtime is not None else 0 + if rank == 0: + LayerNormContributionHook._save_channel_importance_results( + activation_hooks, activations_log_dir, args + ) + + runtime.wait_for_everyone() + + @staticmethod + def _save_channel_importance_results( + activation_hooks: dict[str, ActivationsHook], + activations_log_dir: Path, + args: argparse.Namespace, + ) -> None: + """ + Save channel importance results from activation hooks. + """ + + # Find all activation files (for multi-rank scenarios) + activations_log_dir = Path(activations_log_dir) + activation_files = list(activations_log_dir.glob("rank_*.pth")) + if not activation_files: + aprint(f"Warning: No activation files found in {activations_log_dir}") + return + + # Load and aggregate activation data from all ranks + all_scores = [] + for activation_file in activation_files: + aprint(f"Loading activations from {activation_file}") + activation_data = torch.load(activation_file, map_location="cpu") + + # Extract scores from the activation data + for module_name, hook_data in activation_data.items(): + if "score" in hook_data: + scores = hook_data["score"] + all_scores.append(scores) + aprint(f"Loaded {len(scores)} channel scores from {module_name}") + + if not all_scores: + aprint("Warning: No valid activation data found") + return + + # Average scores across all ranks and modules + avg_scores = torch.stack(all_scores).mean(dim=0) + aprint(f"Averaged {len(all_scores)} score sets into {len(avg_scores)} channels") + + # Create channel importance ranking (descending order) + ranked_channels = torch.argsort(avg_scores, descending=True).tolist() + + # Create output data structure + timestamp = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") + output_data = { + "model_path": getattr(args, "model_name_or_path", "unknown"), + "dataset_path": getattr(args, "dataset_path", "unknown"), + "experiment_id": getattr(args, "experiment_id", f"experiment_{timestamp}"), + "eval_samples": getattr(args, "eval_samples", 0), + "micro_batch_size": getattr(args, "micro_batch_size", 0), + "timestamp": timestamp, + "total_channels": len(ranked_channels), + "channel_importance_ranking": ranked_channels, + "channel_scores": avg_scores.tolist(), + "score_statistics": { + "min": float(avg_scores.min()), + "max": float(avg_scores.max()), + "mean": float(avg_scores.mean()), + "std": float(avg_scores.std()), + }, + } + + # Save the output + output_path = activations_log_dir / "channel_importance_results.json" + aprint(f"Saving channel importance data to {output_path}") + with open(output_path, "w") as f: + json.dump(output_data, f, indent=2) + + # Print summary statistics + aprint("=== Channel Importance Summary ===") + aprint(f"Total channels: {len(ranked_channels)}") + aprint(f"Top 10 most important channels: {ranked_channels[:10]}") + aprint(f"Bottom 10 least important channels: {ranked_channels[-10:]}") + aprint(f"Score range: {avg_scores.min():.4f} to {avg_scores.max():.4f}") + aprint(f"Score mean: {avg_scores.mean():.4f}") + aprint(f"Score std: {avg_scores.std():.4f}") + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "score": self.agg_embedding_activations.cpu(), + "channels_importance_ascending": self.agg_embedding_activations.sort()[1].cpu(), + } diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py new file mode 100644 index 0000000000..457e37d74e --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Provides a function to register activation hooks for a model. +Activation hooks are used to compute activation scores for pruning.""" + +import re + +from modelopt.torch._compress.activation_scoring.activation_hooks import hooks +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM + + +def register_activation_hooks( + model: DeciLMForCausalLM, activation_hooks_kwargs: dict +) -> tuple[dict[str, hooks.ActivationsHook], hooks.ActivationsHook]: + hook_class_map = { + "mlp.down_proj": { + "independent": hooks.IndependentChannelContributionHook, + "iterative": hooks.IterativeChannelContributionHook, + }, + "self_attn.o_proj": { + "independent_kv_head_contribution": hooks.IndependentKvHeadContributionHook, + }, + r"regex:experts\.\d+\.down_proj$": { # For MoE + "independent": hooks.IndependentChannelContributionHook, + }, + # TODO: maybe this is too generic, and we should have it specifically for + # input_layernorm and post_attention_layernorm; now it might select qk_norms + "layernorm": { + "layer_norm_contribution": hooks.LayerNormContributionHook, + }, + } + + activation_hooks = {} + target_layer = activation_hooks_kwargs.get("target_layer", "mlp.c_proj") + + if target_layer.startswith("regex:"): + target_layer_regex = target_layer[len("regex:") :] + pattern = re.compile(target_layer_regex) + + def match_predicate(module_name, module): + return pattern.search(module_name) + else: + + def match_predicate(module_name, module): + return module_name.endswith(target_layer) + + target_layer_hooks_map = hook_class_map.get(target_layer) + if target_layer_hooks_map is None: + raise ValueError(f"no hook classes found for: {target_layer}") + + hook_class = target_layer_hooks_map.get(activation_hooks_kwargs["method"]) + if hook_class is None: + raise ValueError(f"Unknown hook class: {hook_class}") + + if target_layer == "block": + pattern = re.compile(r"^transformer\.h\.\d+$") + + def match_predicate(module_name, module): + return pattern.match(module_name) + + activation_hooks_kwargs["model"] = model + for module_name, module in model.named_modules(): + if match_predicate(module_name, module): + block_config = None + if block_idx_match := re.search(r"\.(\d+)\.", module_name): + block_idx = int(block_idx_match.group(1)) + block_config = model.config.block_configs[block_idx] + curr_activation_hooks_kwargs = { + **activation_hooks_kwargs, + "block_config": block_config, + } + + hook = hook_class(module, curr_activation_hooks_kwargs) + module.register_forward_hook(hook) + activation_hooks[module_name] = hook + + return activation_hooks, hook_class diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index e264ea6813..37a49ed236 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Provides a function to validate a model. Runs a model forward pass on a dataset and calculates +the loss, and optionally registers hooks to capture the inputs and the outputs +of pytorch modules that are used for activation scoring for pruning.""" + import argparse import textwrap from pathlib import Path @@ -27,12 +32,14 @@ PreTrainedModel, PreTrainedTokenizerBase, ) -from utils.activation_hooks.utils import register_activation_hooks from utils.data.dataloaders import create_validation_dataloader from utils.parsing import simple_parse_args_string from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline from utils.validation import calculate_losses +from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( + register_activation_hooks, +) from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint from modelopt.torch._compress.tools.logger import aprint, mprint from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime @@ -212,7 +219,7 @@ def validate_model( Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str) # TODO: send_slack_message(results_str) - if args.activations_log_dir is not None: + if activation_hooks is not None: hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args, runtime) return losses, hidden_states_per_batch diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index ef952dfec6..6e2ba9339a 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -25,7 +25,6 @@ def __init__(self, device=None, dtype=None): device: `torch.device` to work with dtype: `torch.dtype` to work with - Example:: with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): model = LLaMA(model_config) From 194b5325e583964df1318a184cedd908f393a952 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 25 Nov 2025 00:02:50 +0100 Subject: [PATCH 15/62] Add sewing kit and utilities used for pruning scoring - pruning scoring is self-contained now (#584) ## What does this PR do? Add sewing kit and utilities used for pruning scoring - pruning scoring is self-contained now - no dependency on internal Nvidia code. --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Daniel Korzekwa Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../score_pruning_activations.py | 2 +- modelopt/torch/_compress/compress.py | 4 +- .../torch/_compress/sewing_kit/__init__.py | 34 + modelopt/torch/_compress/sewing_kit/common.py | 19 + modelopt/torch/_compress/sewing_kit/core.py | 881 ++++++++++++++++++ .../_compress/sewing_kit/passage/__init__.py | 28 + .../_compress/sewing_kit/passage/core.py | 459 +++++++++ modelopt/torch/_compress/sewing_kit/utils.py | 506 ++++++++++ modelopt/torch/_compress/tools/__init__.py | 15 + modelopt/torch/_compress/tools/kd_model.py | 53 ++ .../tools/sharded_checkpoint_utils.py | 3 +- .../torch/_compress/tools/validate_model.py | 13 +- .../_compress/utils/checkpoint_manager.py | 276 ++++++ .../torch/_compress/utils/data/dataloaders.py | 326 +++++++ .../torch/_compress/utils/data/dataset.py | 319 +++++++ modelopt/torch/_compress/utils/parsing.py | 455 +++++++++ .../utils/validate_runtime_pipeline.py | 390 ++++++++ modelopt/torch/_compress/utils/validation.py | 826 ++++++++++++++++ pyproject.toml | 7 +- .../configs/validate_model_defaults.yaml | 2 +- tests/gpu/torch/export/test_fsdp2_export.py | 1 - 21 files changed, 4605 insertions(+), 14 deletions(-) create mode 100644 modelopt/torch/_compress/sewing_kit/__init__.py create mode 100644 modelopt/torch/_compress/sewing_kit/common.py create mode 100644 modelopt/torch/_compress/sewing_kit/core.py create mode 100644 modelopt/torch/_compress/sewing_kit/passage/__init__.py create mode 100644 modelopt/torch/_compress/sewing_kit/passage/core.py create mode 100644 modelopt/torch/_compress/sewing_kit/utils.py create mode 100644 modelopt/torch/_compress/tools/__init__.py create mode 100644 modelopt/torch/_compress/tools/kd_model.py create mode 100644 modelopt/torch/_compress/utils/checkpoint_manager.py create mode 100644 modelopt/torch/_compress/utils/data/dataloaders.py create mode 100644 modelopt/torch/_compress/utils/data/dataset.py create mode 100644 modelopt/torch/_compress/utils/parsing.py create mode 100644 modelopt/torch/_compress/utils/validate_runtime_pipeline.py create mode 100644 modelopt/torch/_compress/utils/validation.py diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py index 3617bdb1c2..ef1e6c2738 100644 --- a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -18,7 +18,7 @@ import hydra import torch from omegaconf import DictConfig -from utils.parsing import format_global_config +from modelopt.torch._compress.utils.parsing import format_global_config from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 7d955c5cae..64e241d104 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -23,10 +23,10 @@ import build_library_and_stats import mip_and_realize_models import pruning_ckpts -import score_pruning_activations +import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations import scoring from omegaconf import DictConfig -from puzzle_tools.runtime import IRuntime +from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir diff --git a/modelopt/torch/_compress/sewing_kit/__init__.py b/modelopt/torch/_compress/sewing_kit/__init__.py new file mode 100644 index 0000000000..6df9f8afa8 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/__init__.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +from .core import ( + Needle, + KnotException, + LoopFoundException, + InputsLoopFoundException, + MultipleExternalNodesException, + OnlyInternalNodesException, + OutputsLoopFoundException, + ExternalTarget, + ModuleTarget, + ConstantTarget, + FunctionTarget, + RemoteTarget, + StitchedModule, + StitchedModuleException, + CantResolveNodeDependenciesException, + StitchedModuleOutput, +) +from .passage import always_false_predicate, always_true_predicate, InputArgs diff --git a/modelopt/torch/_compress/sewing_kit/common.py b/modelopt/torch/_compress/sewing_kit/common.py new file mode 100644 index 0000000000..5bc5732320 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/common.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +logger = logging.getLogger("sewing_kit") +logger.setLevel(logging.WARN) diff --git a/modelopt/torch/_compress/sewing_kit/core.py b/modelopt/torch/_compress/sewing_kit/core.py new file mode 100644 index 0000000000..550c1298ca --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/core.py @@ -0,0 +1,881 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +from __future__ import annotations +from abc import ABC +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union +from typing_extensions import override + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +import torch +import torch.distributed +import torch.nn as nn + +from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip +from .passage import ( + Passage, + InputArgs, + OutputValue, + Predicate, + always_false_predicate, + PassageInputAdapter, + PassageOutputAdapter, + PassageInputOverrides, + PassageOutputOverrides, +) + + +InputAdapter = Callable[[InputArgs], InputArgs] +OutputAdapter = Callable[..., OutputValue] + + +def default_input_adapter_fn(input_values: InputArgs) -> InputArgs: + return input_values + + +def default_output_adapter_fn(v: OutputValue) -> OutputValue: + return v + + +@dataclass +class IOReducer: + pass + + +def default_input_reducer_fn(acc: InputArgs, input_override: InputArgs, *args): + return acc + input_override + + +@dataclass +class InputReducer(IOReducer): + reducer_fn: Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs] = ( + default_input_reducer_fn + ) + + def __call__( + self, + acc: InputArgs, + input_override: InputArgs, + original_input: InputArgs, + index: int, + all_input_overrides: list[InputArgs], + ) -> InputArgs: + result = self.reducer_fn(acc, input_override, original_input, index, all_input_overrides) + return result + + @classmethod + def default(cls) -> InputReducer: + return InputReducer() + + +def default_output_reducer_fn(acc: OutputValue, input_override: OutputValue, *args): + return input_override + + +@dataclass +class OutputReducer(IOReducer): + reducer_fn: Callable[ + [OutputValue, OutputValue, Optional[OutputValue], int, list[OutputValue]], OutputValue + ] = default_output_reducer_fn + requires_original_output: bool = False + + def __call__( + self, + acc: OutputValue, + output_override: OutputValue, + original_output: Optional[OutputValue], + index: int, + all_output_overrides: list[OutputValue], + ) -> InputArgs: + result = self.reducer_fn(acc, output_override, original_output, index, all_output_overrides) + return result + + @classmethod + def default(cls) -> OutputReducer: + return OutputReducer() + + +class Singleton(type): + _instances = {} + + @override + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class Target: + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class TargetWithInput(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def input( + self, + adapter: InputAdapter = default_input_adapter_fn, + reducer: InputReducer = InputReducer.default(), + ) -> InputDescriptor: + result = InputDescriptor(self, input_name="", input_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithNamedInputs(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def input( + self, + name: str, + adapter: InputAdapter = default_input_adapter_fn, + reducer: InputReducer = InputReducer.default(), + ) -> InputDescriptor: + result = InputDescriptor(self, input_name=name, input_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithOutput(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def output( + self, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name="", output_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithNamedOutputs(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def output( + self, + name: str, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class ExternalTarget(TargetWithNamedInputs, TargetWithNamedOutputs, metaclass=Singleton): + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class ConstantTarget(TargetWithOutput): + name: str + value: Any + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class FunctionTarget(TargetWithInput, TargetWithOutput): + name: str + function: Callable[..., Any] + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class ModuleTarget(TargetWithNamedInputs, TargetWithNamedOutputs): + name: str + module: nn.Module + + @override + def __str__(self) -> str: + return f"ModuleTarget({self.name})" + + @override + def __repr__(self) -> str: + return str(self) + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class RemoteTarget(Target): + peer_rank: Union[int, Sequence[int]] + process_group: Optional[torch.distributed.ProcessGroup] = None + blocking: bool = True + + @override + def __hash__(self) -> int: + return super().__hash__() + + def value( + self, + name: str, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer) + return result + + +@dataclass(frozen=True, eq=True) +class RemoteDataDescriptor(ABC): + key: str + + +@dataclass(frozen=True, eq=True) +class RemoteTensorDataDescriptor(RemoteDataDescriptor): + device: Literal["cuda", "cpu"] + dtype: torch.dtype + shape: torch.Size + + +@dataclass(frozen=True, eq=True) +class RemotePythonDataDescriptor(RemoteDataDescriptor): + value: Any + + +@dataclass +class Node: + target: Target + stitches_to: list[StitchDescriptor] = field(default_factory=list) + stitches_from: list[StitchDescriptor] = field(default_factory=list) + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class InputDescriptor: + target: Target + input_name: str = "" + input_adapter: InputAdapter = field(default=default_input_adapter_fn) + reducer: InputReducer = field(default_factory=InputReducer.default) + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class OutputDescriptor: + target: Target + output_name: str = "" + output_adapter: OutputAdapter = field(default=default_output_adapter_fn) + reducer: OutputReducer = field(default_factory=OutputReducer.default) + + @override + def __hash__(self) -> int: + return id(self) + + +IODescriptor = Union[InputDescriptor, OutputDescriptor] + + +@dataclass +class StitchDescriptor: + source_descriptor: IODescriptor + destination_descriptor: IODescriptor + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class StitchedModuleOutput: + captured_inputs: dict[str, InputArgs] + captured_outputs: dict[str, Any] + + +class StitchedModuleException(Exception): + pass + + +class CantResolveNodeDependenciesException(StitchedModuleException): + pass + + +class StitchedModule(nn.Module): + def __init__( + self, + nodes: dict[Target, Node], + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit=True, + ignore_extra_overrides=False, + ) -> None: + super().__init__() + self.nodes = nodes + self.ignore_extra_overrides = ignore_extra_overrides + external_nodes = [n for n in nodes.values() if isinstance(n.target, ExternalTarget)] + remote_nodes = [n for n in nodes.values() if isinstance(n.target, RemoteTarget)] + assert len(external_nodes) <= 1 + assert len(remote_nodes) + len(external_nodes) > 0 + self.external_node = external_nodes[0] if len(external_nodes) > 0 else None + self.internal_nodes = [ + n for n in nodes.values() if not isinstance(n.target, ExternalTarget) + ] + self.values_from_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict) + self.values_to_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict) + + self.node_passages: dict[Node, Passage] = { + node: Passage.create( + module=node.target.module, + inputs_to_capture=set( + s.source_descriptor.input_name + for s in node.stitches_from + if isinstance(s.source_descriptor, InputDescriptor) + ), + outputs_to_capture=set( + s.source_descriptor.output_name + for s in node.stitches_from + if isinstance(s.source_descriptor, OutputDescriptor) + ), + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + name=getattr(node.target, "name", None), + ) + for node in self.internal_nodes + if isinstance(node.target, ModuleTarget) + } + + self.passage_modules = nn.ModuleDict( + { + f"node_{node_index}": self.node_passages[node] + for node_index, node in enumerate(nodes.values()) + if node in self.node_passages + } + ) + self.adapter_modules = nn.ModuleDict( + { + f"node_{node_index}__stitch_{stitch_index}__{descriptor_name}": adapter + for node_index, node in enumerate(nodes.values()) + for stitch_index, stitch in enumerate(node.stitches_from + node.stitches_to) + for descriptor_name, descriptor in ( + ("source", stitch.source_descriptor), + ("destination", stitch.destination_descriptor), + ) + for adapter in [ + descriptor.input_adapter + if isinstance(descriptor, InputDescriptor) + else descriptor.output_adapter + ] + if isinstance(adapter, nn.Module) + } + ) + + def create_input_overrides( + self, values_to_node: dict[IODescriptor, Any] + ) -> PassageInputOverrides: + input_descriptors_by_group = defaultdict[str, list[InputDescriptor]](list) + for io_descriptor in values_to_node.keys(): + if isinstance(io_descriptor, InputDescriptor): + input_descriptors_by_group[io_descriptor.input_name].append(io_descriptor) + + input_overrides = PassageInputOverrides() + for group, input_descriptors in input_descriptors_by_group.items(): + reducers = [d.reducer for d in input_descriptors] + + def create_reducer(input_descriptors=input_descriptors, reducers=reducers): + inputs = [values_to_node[d] for d in input_descriptors] + + def reducer_fn( + original_input: InputArgs, + module_name: Optional[str], + module: Optional[nn.Module], + ) -> InputArgs: + acc = InputArgs() + for i, (input_, reducer) in enumerate(zip(inputs, reducers)): + acc = reducer(acc, input_, original_input, i, inputs) + return acc + + return reducer_fn + + input_override = PassageInputAdapter(create_reducer()) + input_overrides[group] = input_override + + return input_overrides + + def create_output_overrides( + self, values_to_node: dict[IODescriptor, Any] + ) -> PassageOutputOverrides: + output_descriptors_by_group = defaultdict[str, list[OutputDescriptor]](list) + for io_descriptor in values_to_node.keys(): + if isinstance(io_descriptor, OutputDescriptor): + output_descriptors_by_group[io_descriptor.output_name].append(io_descriptor) + + output_overrides = PassageOutputOverrides() + for group, output_descriptors in output_descriptors_by_group.items(): + reducers = [d.reducer for d in output_descriptors] + requires_original_output = any(r.requires_original_output for r in reducers) + + def create_reducer(reducers=reducers): + outputs = [values_to_node[d] for d in output_descriptors] + + def reducer_fn( + original_output: Optional[OutputValue], + module_name: Optional[str], + module: Optional[nn.Module], + ) -> OutputValue: + acc = None + for i, (output, reducer) in enumerate(zip(outputs, reducers)): + acc = reducer(acc, output, original_output, i, outputs) + return acc + + return reducer_fn + + reducer_fn = create_reducer() + if requires_original_output: + output_override = PassageOutputAdapter(reducer_fn) + else: + output_override = reducer_fn(None, None, None) + + output_overrides[group] = output_override + + return output_overrides + + @override + def __call__( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + return super().__call__(input_overrides, output_overrides, *args, **kwargs) + + @override + @dynamo_skip + def forward( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + input_overrides = {k: InputArgs.from_value(v) for k, v in input_overrides.items()} + + self.values_from_node.clear() + self.values_to_node.clear() + + unresolved_count: int = 0 + nodes_stack: list[Node] = ( + [] if self.external_node is None else [self.external_node] + ) + self.internal_nodes + while len(nodes_stack) > 0: + node = nodes_stack.pop(0) + values_from_node = self.values_from_node[node] + values_to_node = self.values_to_node[node] + + if isinstance(node.target, ExternalTarget): + assert self.external_node is not None + + if not self.ignore_extra_overrides: + input_override_names = set(input_overrides.keys()) + external_node_input_names = set( + s.source_descriptor.input_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, InputDescriptor) + ) + assert input_override_names == external_node_input_names + output_override_names = set(output_overrides.keys()) + external_node_output_names = set( + s.source_descriptor.output_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, OutputDescriptor) + ) + assert output_override_names == external_node_output_names + + for stitch in self.external_node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + orig_input_override = input_overrides[stitch.source_descriptor.input_name] + input_override = stitch.source_descriptor.input_adapter(orig_input_override) + values_from_node[stitch.source_descriptor] = input_override + elif isinstance(stitch.source_descriptor, OutputDescriptor): + orig_output_override = output_overrides[ + stitch.source_descriptor.output_name + ] + output_override = stitch.source_descriptor.output_adapter( + orig_output_override + ) + values_from_node[stitch.source_descriptor] = output_override + else: + raise RuntimeError("Shouldn't happen") + + else: + if len(values_to_node) < len(node.stitches_to): + nodes_stack.append(node) + unresolved_count += 1 + if unresolved_count >= len(nodes_stack): + raise CantResolveNodeDependenciesException( + "Can't resolve nodes dependencies" + ) + continue + + if isinstance(node.target, ConstantTarget): + assert len(values_to_node) == 0 + + output_value = node.target.value + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(output_value) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, FunctionTarget): + assert all( + isinstance(v, InputDescriptor) and v.input_name == "" + for v in values_to_node + ) + + function_input_overrides = self.create_input_overrides(values_to_node)[""] + + if isinstance(function_input_overrides, InputArgs): + input_args = function_input_overrides + else: + input_args = function_input_overrides(InputArgs(), None, None) + + function_output = node.target.function(*input_args.args, **input_args.kwargs) + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(function_output) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, ModuleTarget): + passage = self.node_passages[node] + passage.input_overrides = self.create_input_overrides(values_to_node) + passage.output_overrides = self.create_output_overrides(values_to_node) + passage_output = passage(*args, **kwargs) + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + captured_input = passage_output.captured_inputs[ + stitch.source_descriptor.input_name + ] + value = stitch.source_descriptor.input_adapter(captured_input) + values_from_node[stitch.source_descriptor] = value + elif isinstance(stitch.source_descriptor, OutputDescriptor): + captured_output = passage_output.captured_outputs[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(captured_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + + elif isinstance(node.target, RemoteTarget): + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_from_node + ) + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_to_node + ) + + process_group = node.target.process_group + peers = node.target.peer_rank + if not isinstance(peers, Sequence): + peers = [peers] + + if len(values_to_node) > 0: + items_to_send = list(self.create_output_overrides(values_to_node).items()) + + data_descriptors: list[RemoteDataDescriptor] = [] + tensors_to_send: list[torch.Tensor] = [] + + for key, value in items_to_send: + if isinstance(value, torch.Tensor): + if value.is_cuda: + tensor_device = "cuda" + elif value.is_cpu: + tensor_device = "cpu" + else: + raise RuntimeError( + f"Invalid tensor device to send to remote target: {value.device}" + ) + + data_descriptor = RemoteTensorDataDescriptor( + key=key, + device=tensor_device, + dtype=value.dtype, + shape=value.shape, + ) + tensors_to_send.append(value) + + else: + data_descriptor = RemotePythonDataDescriptor( + key=key, + value=value, + ) + + data_descriptors.append(data_descriptor) + + works: list[Optional[torch.distributed.Work]] = [] + for peer in peers: + if process_group is not None: + peer = torch.distributed.get_global_rank(process_group, peer) + + peer_works = distributed_isend_obj(data_descriptors, dst=peer) + works.extend(peer_works) + + for tensor in tensors_to_send: + work = torch.distributed.isend(tensor, dst=peer) + works.append(work) + + if node.target.blocking: + for work in works: + if work is not None: + work.wait() + + pass + + if len(node.stitches_from) > 0: + assert len(peers) == 1, ( + f"Cannot use multiple peers when using RemoteTarget as a source ({peers=})" + ) + (peer,) = peers + + if process_group is not None: + peer = torch.distributed.get_global_rank(process_group, peer) + + data_descriptors = distributed_recv_obj(src=peer) + assert isinstance(data_descriptors, list) + + tensors_to_recv: list[torch.Tensor] = [] + received_values: dict[str, Any] = {} + for data_descriptor in data_descriptors: + if isinstance(data_descriptor, RemoteTensorDataDescriptor): + tensor = torch.empty( + data_descriptor.shape, + dtype=data_descriptor.dtype, + device=data_descriptor.device, + ) + tensors_to_recv.append(tensor) + received_values[data_descriptor.key] = tensor + elif isinstance(data_descriptor, RemotePythonDataDescriptor): + received_values[data_descriptor.key] = data_descriptor.value + else: + raise RuntimeError( + f"Received invalid data descriptor from remote peer: {data_descriptor}" + ) + + works: list[Optional[torch.distributed.Work]] = [] + for tensor in tensors_to_recv: + work = torch.distributed.irecv(tensor, src=peer) + works.append(work) + + for work in works: + if work is not None: + work.wait() + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, OutputDescriptor): + remote_output = received_values[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(remote_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + else: + raise RuntimeError("Shouldn't happen") + + for stitch in node.stitches_from: + dst_node = self.nodes[stitch.destination_descriptor.target] + value = values_from_node[stitch.source_descriptor] + + if isinstance(stitch.destination_descriptor, InputDescriptor): + value = stitch.destination_descriptor.input_adapter(value) + elif isinstance(stitch.destination_descriptor, OutputDescriptor): + value = stitch.destination_descriptor.output_adapter(value) + else: + raise RuntimeError("Shouldn't happen") + + self.values_to_node[dst_node][stitch.destination_descriptor] = value + + unresolved_count = 0 + + values_to_external_node = ( + {} if self.external_node is None else self.values_to_node[self.external_node] + ) + output = StitchedModuleOutput( + captured_inputs={ + k.input_name: v + for k, v in values_to_external_node.items() + if isinstance(k, InputDescriptor) + }, + captured_outputs={ + k.output_name: v + for k, v in values_to_external_node.items() + if isinstance(k, OutputDescriptor) + }, + ) + + self.values_from_node.clear() + self.values_to_node.clear() + + return output + + +class KnotException(Exception): + pass + + +class LoopFoundException(KnotException): + pass + + +class InputsLoopFoundException(LoopFoundException): + pass + + +class OutputsLoopFoundException(LoopFoundException): + pass + + +class MultipleExternalNodesException(KnotException): + pass + + +class OnlyInternalNodesException(KnotException): + pass + + +class Needle: + def __init__(self) -> None: + self.nodes = dict[Target, Node]() + + def get_node_for_target(self, target: Target) -> Node: + if target not in self.nodes: + node = Node(target=target) + self.nodes[target] = node + else: + node = self.nodes[target] + + return node + + def stitch(self, src: IODescriptor, dst: IODescriptor) -> Self: + descriptor = StitchDescriptor(source_descriptor=src, destination_descriptor=dst) + + src_node = self.get_node_for_target(descriptor.source_descriptor.target) + dst_node = self.get_node_for_target(descriptor.destination_descriptor.target) + + if descriptor not in src_node.stitches_from: + src_node.stitches_from.append(descriptor) + + if descriptor not in dst_node.stitches_to: + dst_node.stitches_to.append(descriptor) + + return self + + def _search_loops( + self, + node: Node, + expand_fn: Callable[[Node], Iterable[IODescriptor]], + traversed_nodes: Optional[set[Node]] = None, + ) -> bool: + if isinstance(node.target, ExternalTarget): + return False + + if traversed_nodes is None: + traversed_nodes = set() + + if node in traversed_nodes: + found_loop = True + else: + traversed_nodes = traversed_nodes | {node} + found_loop = False + descriptors = expand_fn(node) + for descriptor in descriptors: + stitch_node = self.get_node_for_target(descriptor.target) + found_loop |= self._search_loops(stitch_node, expand_fn, traversed_nodes) + + return found_loop + + def _validate_nodes(self): + # internal_nodes = [n for n in self.nodes.values() if not isinstance(n.target, (ExternalTarget, RemoteTarget))] + external_nodes = [n for n in self.nodes.values() if isinstance(n.target, ExternalTarget)] + remote_nodes = [n for n in self.nodes.values() if isinstance(n.target, RemoteTarget)] + + if len(external_nodes) + len(remote_nodes) == 0: + raise OnlyInternalNodesException(f"Has only internal nodes") + + if len(external_nodes) > 1: + raise MultipleExternalNodesException( + f"Expected no more than 1 external node, found {len(external_nodes)}" + ) + + for i, node in enumerate(self.nodes.values()): + found_inputs_loop = self._search_loops( + node, lambda n: [s.source_descriptor for s in n.stitches_to] + ) + if found_inputs_loop: + raise InputsLoopFoundException(f"Found a loop in inputs of node {i}: {node}") + + found_outputs_loop = self._search_loops( + node, lambda n: [s.destination_descriptor for s in n.stitches_from] + ) + if found_outputs_loop: + raise OutputsLoopFoundException(f"Found a loop in outputs of node {i}: {node}") + + def knot( + self, + capture_cache_outputs_predicate=always_false_predicate, + early_exit=True, + ignore_extra_overrides=False, + ) -> StitchedModule: + self._validate_nodes() + + module = StitchedModule( + nodes=self.nodes, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + ignore_extra_overrides=ignore_extra_overrides, + ) + + return module diff --git a/modelopt/torch/_compress/sewing_kit/passage/__init__.py b/modelopt/torch/_compress/sewing_kit/passage/__init__.py new file mode 100644 index 0000000000..98dfc683b3 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/passage/__init__.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .core import ( + Passage, + PassageOutput, + InputArgs, + OutputValue, + Predicate, + PassageInputAdapter, + PassageOutputAdapter, + PassageInputOverrides, + PassageOutputOverrides, + always_true_predicate, + always_false_predicate, +) diff --git a/modelopt/torch/_compress/sewing_kit/passage/core.py b/modelopt/torch/_compress/sewing_kit/passage/core.py new file mode 100644 index 0000000000..4a66638aac --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/passage/core.py @@ -0,0 +1,459 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +from __future__ import annotations +import sys + +from collections.abc import Sequence, Callable + +from dataclasses import dataclass +from typing import Any, ContextManager, Iterable, Mapping, Optional, Union + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +from typing_extensions import override + +import torch.nn as nn +from ..utils import ( + ActivityContext, + has_fake_tensor, + fake_tensors, + is_submodule_of, + is_submodule_or_same, + real_tensors, + dynamo_skip, +) +from ..common import logger + + +@dataclass +class InputArgs: + args: list[Any] + kwargs: dict[str, Any] + + def __init__(self, *args, **kwargs): + self.args = list(args) + self.kwargs = dict(kwargs) + + def __add__(self, other: Any) -> InputArgs: + assert isinstance(other, InputArgs) + result = InputArgs(*self.args, *other.args, **{**self.kwargs, **other.kwargs}) + return result + + def drop_args(self, index: int | slice | None = None) -> InputArgs: + new_args = InputArgs(*self.args, **self.kwargs) + if index is None: + new_args.args.clear() + else: + del new_args.args[index] + + return new_args + + def drop_kwargs(self, keys: Sequence[str] | None = None) -> InputArgs: + new_args = InputArgs(*self.args, **self.kwargs) + if keys is None: + new_args.kwargs.clear() + else: + for key in keys: + new_args.kwargs.pop(key, None) + + return new_args + + @classmethod + def from_value(cls, v): + if isinstance(v, cls): + return v + elif isinstance(v, InputArgs): + return cls(*v.args, **v.kwargs) + elif isinstance(v, Sequence): + return cls(*v) + else: + return cls(v) + + +OutputValue = Any + + +@dataclass +class PassageInputAdapter: + adapter_fn: Callable[[InputArgs, Optional[str], Optional[nn.Module]], InputArgs] + + def __call__( + self, original_input: InputArgs, module_name: Optional[str], module: Optional[nn.Module] + ) -> InputArgs: + result = self.adapter_fn(original_input, module_name, module) + return result + + +@dataclass +class PassageOutputAdapter: + adapter_fn: Callable[[Any, Optional[str], Optional[nn.Module]], Any] + + def __call__( + self, original_output: Any, module_name: Optional[str], module: Optional[nn.Module] + ) -> Any: + result = self.adapter_fn(original_output, module_name, module) + return result + + +class PassageInputOverrides(dict[str, Union[PassageInputAdapter, InputArgs]]): + def __init__(self, input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}): + for k, v in input_overrides.items(): + self[k] = v + + # def __setitem__(self, key: str, value: InputAdapter | InputArgs) -> None: + # if isinstance(key, InputArgs): + # def adapter_fn(original_input: InputArgs) -> InputArgs: + # assert isinstance(value, InputArgs) + # return value + # self[key] = InputAdapter(adapter_fn) + # else: + # self[key] = value + + +class PassageOutputOverrides(dict[str, Union[PassageOutputAdapter, Any]]): + def __init__(self, output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}): + for k, v in output_overrides.items(): + self[k] = v + + +class NoActivePassageContextError(RuntimeError): + pass + + +class RequiredPassageOutputsCapturedSignal(Exception): + pass + + +@dataclass +class PassageOutput: + captured_inputs: dict[str, InputArgs] + captured_outputs: dict[str, Any] + captured_fake_outputs: dict[str, Any] + module_output: Any + + +Predicate = Callable[[str, nn.Module], bool] + + +def always_false_predicate(module_name: str, module: nn.Module) -> bool: + return False + + +def always_true_predicate(module_name: str, module: nn.Module) -> bool: + return True + + +class Passage(nn.Module): + create_fn_context = ActivityContext[None](max_depth=1) + active_passages_context = ActivityContext["Passage"](no_duplicates=True, reversed=True) + + def __init__( + self, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ): + super().__init__() + + if not self.create_fn_context.is_active(): + raise RuntimeError("Please use Passage.create(...) in order to create a new Passage") + + self.active_context_manager: Optional[ContextManager] = None + + self.name = name + self.module = module + self.module_to_name_mapping = {id(v): k for k, v in module.named_modules()} + self.inputs_to_capture = set(inputs_to_capture) + self.outputs_to_capture = set(outputs_to_capture) + self.input_overrides = input_overrides + self.output_overrides = output_overrides + self.outputs_cache = outputs_cache + self.capture_fake_outputs_predicate = capture_fake_outputs_predicate + self.capture_cache_outputs_predicate = capture_cache_outputs_predicate + self.early_exit = early_exit + + self.reset() + + @property + def input_overrides(self) -> PassageInputOverrides: + return self._input_overrides + + @input_overrides.setter + def input_overrides(self, value: Mapping[str, PassageInputAdapter | InputArgs]): + self._input_overrides = PassageInputOverrides(value) + + @property + def output_overrides(self) -> PassageOutputOverrides: + return self._output_overrides + + @output_overrides.setter + def output_overrides(self, value: Mapping[str, PassageOutputAdapter | Any]): + self._output_overrides = PassageOutputOverrides(value) + + def reset(self): + self.required_capture_count = ( + (len(self.inputs_to_capture) + len(self.outputs_to_capture)) + if self.early_exit + else None + ) + self.captured_outputs: dict[str, Any] = {} + self.captured_inputs: dict[str, InputArgs] = {} + self.captured_fake_outputs: dict[str, Any] = {} + + @classmethod + def module_name_relative_to_active_passage(cls, module: PatchedModule) -> str: + root_passage = Passage.active_passages_context.get_active() + assert root_passage is not None + module_name = root_passage.module_to_name_mapping[id(module)] + return module_name + + @classmethod + def create( + cls, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ) -> Passage: + with cls.create_fn_context(None): + passage = cls( + module=module, + inputs_to_capture=inputs_to_capture, + outputs_to_capture=outputs_to_capture, + input_overrides=input_overrides, + output_overrides=output_overrides, + outputs_cache=outputs_cache, + capture_fake_outputs_predicate=capture_fake_outputs_predicate, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + name=name, + ) + + for submodule_name, submodule in module.named_modules(remove_duplicate=False): + patch_module(submodule_name, submodule) + + # register_passage_hooks(module, descriptor) + + return passage + + def is_active(self) -> bool: + result = self.active_context_manager is not None + return result + + def __enter__(self): + assert self.active_context_manager is None + self.active_context_manager = Passage.active_passages_context(self) + self.active_context_manager.__enter__() + self.module_to_name_mapping = {id(v): k for k, v in self.named_modules()} + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.active_context_manager is not None + self.active_context_manager.__exit__(exc_type, exc_val, exc_tb) + + def freeze(self): + self.eval() + self.requires_grad_(False) + + def unfreeze(self): + self.train() + self.requires_grad_(True) + + def run(self, *args, **kwargs) -> PassageOutput: + return self(*args, **kwargs) + + @override + def __call__(self, *args, **kwargs) -> PassageOutput: + return super().__call__(*args, **kwargs) + + @dynamo_skip + @override + def forward(self, *args, **kwargs) -> PassageOutput: + self.reset() + + with Passage.active_passages_context(self): + try: + module_output = self.module(*args, **kwargs) + except RequiredPassageOutputsCapturedSignal: + module_output = None + + output = PassageOutput( + captured_inputs=self.captured_inputs, + captured_outputs=self.captured_outputs, + captured_fake_outputs=self.captured_fake_outputs, + module_output=module_output, + ) + + self.reset() + + return output + + +class PatchedModule: ... + + +def patch_module(module_name_: str, module: nn.Module): + # orig_forward = module.forward + + if isinstance(module, PatchedModule): + # if module_name != Passage.module_name_relative_to_active_passage(module): + # logger.warn(f'Module "{module_name}" already patched for module "{Passage.module_name_relative_to_active_passage(module)}". Could lead to bugs.') + return + + orig_class = module.__class__ + + class PassageModuleWrapper(orig_class, PatchedModule): + # Defined as a static method to avoid potential collision with original class methods + @staticmethod + @dynamo_skip + def can_be_skipped(_self: PassageModuleWrapper, depth: int) -> bool: + passages_beyond_depth = Passage.active_passages_context[depth:] + module_name = Passage.module_name_relative_to_active_passage(_self) + + results = [ + ( + module_name in passage.outputs_cache + and not any( + is_submodule_or_same(k, module_name) for k in passage.outputs_to_capture + ) + and not any( + is_submodule_of(k, module_name) + for k, v in passage.input_overrides.items() + if v is not None + ) + and not any( + is_submodule_of(k, module_name) + for k, v in passage.output_overrides.items() + if v is not None + ) + ) + for passage in passages_beyond_depth + ] + + result = all(results) + + return result + + # Defined as a static method to avoid potential collision with original class methods + @staticmethod + @dynamo_skip + def run_passage(_self: PassageModuleWrapper, depth: int, args, kwargs): + if depth + 1 > len(Passage.active_passages_context): + output = super(PassageModuleWrapper, _self).__call__(*args, **kwargs) + return output + + module_name = Passage.module_name_relative_to_active_passage(_self) + passage = Passage.active_passages_context[depth] + + has_output_override = module_name in passage.output_overrides + output_override = passage.output_overrides.get(module_name) + + if has_output_override and not isinstance(output_override, PassageOutputAdapter): + output = output_override + else: + input_override = passage.input_overrides.get(module_name) + if input_override is not None: + original_input_args = InputArgs(*args, **kwargs) + + if isinstance(input_override, PassageInputAdapter): + new_input_args = input_override(original_input_args, module_name, module) + else: + new_input_args = input_override + + args, kwargs = new_input_args.args, new_input_args.kwargs + + if ( + output_override is None + and PassageModuleWrapper.can_be_skipped(_self, depth) + and (has_fake_tensor(args) or has_fake_tensor(kwargs)) + ): + cached_output = passage.outputs_cache[module_name] + return cached_output + + output = PassageModuleWrapper.run_passage( + _self=_self, + depth=depth + 1, + args=args, + kwargs=kwargs, + ) + + if isinstance(output_override, PassageOutputAdapter): + output = output_override(output, module_name, module) + + if passage.capture_fake_outputs_predicate(module_name, module): + fake_output = fake_tensors(output) + passage.captured_fake_outputs[module_name] = fake_output + + if not module_name in passage.outputs_cache and passage.capture_cache_outputs_predicate( + module_name, module + ): + fake_output = fake_tensors(output) + passage.outputs_cache[module_name] = fake_output + + if module_name in passage.inputs_to_capture: + real_args, real_kwargs = real_tensors(args), real_tensors(kwargs) + passage.captured_inputs[module_name] = InputArgs(*real_args, **real_kwargs) + + if passage.required_capture_count is not None: + passage.required_capture_count -= 1 + + if module_name in passage.outputs_to_capture: + real_output = real_tensors(output) + output_value = real_output + passage.captured_outputs[module_name] = output_value + + if passage.required_capture_count is not None: + passage.required_capture_count -= 1 + + if passage.required_capture_count == 0: + raise RequiredPassageOutputsCapturedSignal() + + return output + + @dynamo_skip + @override + def __call__(self, *args, **kwargs): + output = self.run_passage( + _self=self, + depth=0, + args=args, + kwargs=kwargs, + ) + return output + + # module.forward = forward + PassageModuleWrapper.__name__ = f"ModuleWrapper({module.__class__.__name__})" + module.__class__ = PassageModuleWrapper diff --git a/modelopt/torch/_compress/sewing_kit/utils.py b/modelopt/torch/_compress/sewing_kit/utils.py new file mode 100644 index 0000000000..ebe90b2a44 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/utils.py @@ -0,0 +1,506 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +from __future__ import annotations + +import inspect +from collections.abc import Sequence, Mapping +from contextlib import contextmanager +from typing import ( + Any, + Callable, + ContextManager, + Generic, + Iterable, + Literal, + Optional, + Protocol, + TypeVar, + cast, + overload, +) +from typing_extensions import override +import torch +import torch.distributed +import torch._dynamo +import torch._C +from torch import Tensor +import torch.utils._pytree as pytree +import torch.nn as nn +import torch.nn.functional as F +from torch._subclasses import FakeTensor, FakeTensorMode + + +Fn = TypeVar("Fn", bound=Callable) + + +class DynamoSkip(Protocol): + @overload + def __call__(self, fn: None = None) -> Callable[[Fn], Fn]: ... + @overload + def __call__(self, fn: Fn) -> Fn: ... + + +class DynamoDisable(Protocol): + @overload + def __call__(self, fn: None = None, disable: bool = False) -> Callable[[Fn], Fn]: ... + @overload + def __call__(self, fn: Fn, disable: bool = False) -> Fn: ... + + +try: + dynamo_skip: DynamoSkip = cast(Any, torch._dynamo.decorators).skip + dynamo_disable: DynamoDisable = cast(Any, torch._dynamo.decorators).disable +except: + dynamo_skip: DynamoSkip = cast(Any, torch._dynamo.eval_frame).skip + dynamo_disable: DynamoDisable = cast(Any, torch._dynamo.eval_frame).disable + + +TModule = TypeVar("TModule", bound=nn.Module) + + +class ModuleRef(Generic[TModule]): + def __init__(self, module: TModule): + self.module = module + + +Reduction = Literal["none", "mean", "sum"] + + +def normalized_mse_loss( + input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6 +): + loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction=reduction + ) + return loss + + +def mse_loss(input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6): + loss = F.mse_loss(input, target, reduction=reduction) + return loss + + +class NormalizedMSELoss(nn.modules.loss._Loss): + __constants__ = ["reduction", "epsilon"] + + def __init__(self, reduction: Reduction = "mean", epsilon: float = 1e-6) -> None: + super().__init__(None, None, reduction) + self.epsilon = epsilon + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + loss = normalized_mse_loss( + input, + target, + cast(Reduction, self.reduction), + self.epsilon, + ) + return loss + + +def vectorwise_normalized_mse_loss(input: Tensor, target: Tensor, epsilon: float = 1e-6): + """ + Like normalized_mse_loss, but the input is treated as a multi-dimensional batch of vectors. + Normalization is done on each vector separately (the last dim), then results are averaged. + """ + return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) + + +def batched_normalized_mse_loss( + input: Tensor, target: Tensor, epsilon: float = 1e-6, batch_dims: Sequence[int] = (0,) +): + """ + Like normalized_mse_loss, but the input is treated as a batch of tensors. + Normalization is done on the non-batch dims, then results are averaged. + """ + norm_dims = list(set(range(input.ndim)) - set(batch_dims)) + norm_of_target_vectors = F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction="none" + ).mean(dim=norm_dims) + vectorwise_mse = F.mse_loss(input, target, reduction="none").mean(dim=norm_dims) + normalized_vectorwise_mse = vectorwise_mse / norm_of_target_vectors + loss = normalized_vectorwise_mse.mean() + return loss + + +class ActivityContextMaxDepthException(Exception): + pass + + +class ActivityContextDuplicateException(Exception): + pass + + +T = TypeVar("T") + + +class ActivityContext(Generic[T]): + def __init__(self, max_depth: Optional[int] = None, no_duplicates=False, reversed=False): + self.activity_stack: list[T] = [] + self.max_depth = max_depth + self.no_duplicates = no_duplicates + self.reversed = reversed + self.head_index = 0 if self.reversed else -1 + + def __contains__(self, value: T) -> bool: + result = value in self.activity_stack + return result + + def __call__(self, value: T) -> ContextManager: + @contextmanager + def fn(): + try: + if self.no_duplicates and value in self.activity_stack: + raise ActivityContextDuplicateException( + f"Activity stack cannot have a duplicate of item {value}" + ) + + self.activity_stack.insert(self.head_index, value) + + if self.max_depth is not None and len(self) > self.max_depth: + raise ActivityContextMaxDepthException( + f"Activity stack exceeds max depth of {self.max_depth}" + ) + + yield + finally: + assert self.is_active() + self.activity_stack.pop(self.head_index) + + return fn() + + def __len__(self) -> int: + result = len(self.activity_stack) + return result + + @overload + def __getitem__(self, key: int) -> T: ... + @overload + def __getitem__(self, key: slice) -> Sequence[T]: ... + def __getitem__(self, key: int | slice) -> T | Sequence[T]: + result = self.activity_stack[key] + return result + + def is_active(self) -> bool: + result = len(self) > 0 + return result + + def get_active(self) -> Optional[T]: + if self.is_active: + return self.activity_stack[-1] + else: + return None + + +def is_submodule_of(module_name: str, other_module_name: str) -> bool: + result = module_name.startswith(f"{other_module_name}.") or ( + module_name != "" and other_module_name == "" + ) + return result + + +def is_submodule_or_same(module_name: str, other_module_name: str) -> bool: + result = module_name == other_module_name or is_submodule_of(module_name, other_module_name) + return result + + +def reduce_losses(losses: Iterable[Tensor]) -> Tensor: + total_loss = None + for loss in losses: + if total_loss is None: + total_loss = loss + else: + total_loss += loss + + if total_loss is None: + return torch.Tensor(torch.nan) + + return total_loss + + +fake_mode = FakeTensorMode( + allow_non_fake_inputs=True, + # allow_fallback_kernels=False, +) + + +@overload +def fake_tensor(t: Tensor, *, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... + + +@overload +def fake_tensor( + size: Sequence[int] | torch.Size, *, dtype: Optional[torch.dtype] = None, use_meta=False +) -> Tensor: ... + + +@overload +def fake_tensor(*args: int, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... + + +class MyFakeTensor(Tensor): + @dynamo_disable + def __init__(self, *args, **kwargs): + super().__init__() + self._t: FakeTensor + + @override + @dynamo_disable + def __repr__(self, *, tensor_contents=None): + return f"MyFakeTensor(shape={list(self._t.shape)}, dtype={self._t.dtype}, device={self._t.device})" + + @classmethod + @override + @dynamo_disable + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + args, kwargs = pytree.tree_map_only(MyFakeTensor, lambda t: t._t, (args, kwargs)) + + types = pytree.tree_map_only(type(MyFakeTensor), lambda t: FakeTensor, types) + + out = func(*args, **kwargs) + + out = pytree.tree_map_only(Tensor, lambda t: MyFakeTensor.create(t), out) + + return out + + __torch_function__ = torch._C._disabled_torch_function_impl + + # @dynamo_disable + # def __getattribute__(self, attr: str): + # if attr in {'_t', 'device', '__repr__', '__torch_function__', '__class__'}: + # return object.__getattribute__(self, attr) + + # result = getattr(self._t, attr) + + # result = pytree.tree_map_only( + # Tensor, lambda t: MyFakeTensor.create(t), result + # ) + # print('__getattribute__', 'attr', attr, 'ret', result) + + # return result + + @property + @dynamo_disable + def device(self): + return self._t.device + + # @property + # @dynamo_disable + # def shape(self): + # return self._t.shape + + # @dynamo_disable + # def size(self): + # return self._t.size() + + # @classmethod + # @dynamo_disable + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # if kwargs is None: + # kwargs = {} + + # args, kwargs = pytree.tree_map_only( + # MyFakeTensor, lambda t: t._t, (args, kwargs) + # ) + + # ret = func(*args, **kwargs) + + # ret = pytree.tree_map_only( + # Tensor, lambda t: MyFakeTensor.create(t), ret + # ) + # print('__torch_function__', 'func', func, 'ret', ret) + + # return ret + + @staticmethod + @dynamo_disable + def __new__(cls, elem, device) -> MyFakeTensor: + self = torch.Tensor._make_subclass( + cls, + elem, + elem.requires_grad, + dispatch_device=True, + device_for_backend_keys=device, + ) + return cast(MyFakeTensor, self) + + @classmethod + @dynamo_disable + def create(cls, data: Tensor) -> MyFakeTensor: + if isinstance(data, MyFakeTensor): + return data + + if isinstance(data, FakeTensor): + t = data + else: + t = FakeTensor.from_tensor(data, fake_mode=fake_mode) + + # my_fake_tensor = MyFakeTensor(torch.empty(t.shape, dtype=t.dtype, device='meta')) + my_fake_tensor = MyFakeTensor( + torch.empty(t.shape, dtype=t.dtype, device="meta"), + t.device, + ) + my_fake_tensor._t = t + + return my_fake_tensor + + +@dynamo_disable +def fake_tensor(*args, **kwargs) -> Tensor: + dtype: Optional[torch.dtype] = kwargs.get("dtype") + use_meta = kwargs.get("use_meta", False) + + if len(args) == 1 and isinstance(args[0], Tensor): + if use_meta: + fake_tensor = torch.empty(args[0].size(), dtype=dtype or args[0].dtype, device="meta") + else: + fake_tensor = MyFakeTensor.create(args[0]) + else: + fake_tensor = torch.empty(*args, dtype=dtype, device="meta") + if not use_meta: + fake_tensor = MyFakeTensor.create(fake_tensor) + + return fake_tensor + + +@dynamo_skip +def fake_tensor_like(t: Tensor, use_meta=False) -> Tensor: + return fake_tensor(t, use_meta=use_meta) + + +T = TypeVar("T") + + +@dynamo_skip +def fake_tensors(value: T, use_meta=False) -> T: + result = pytree.tree_map_only(Tensor, lambda t: fake_tensor_like(t, use_meta), value) + return result + # if isinstance(value, Mapping): + # return cast(Any, value.__class__)({k: fake_tensors(v, use_meta) for k, v in value.items()}) + # if isinstance(value, Sequence): + # return cast(Any, value.__class__)([fake_tensors(v, use_meta) for v in value]) + # if isinstance(value, Tensor): + # return fake_tensor_like(value, use_meta) + # return value + + +@dynamo_skip +def real_tensors(value: Any) -> Any: + result = pytree.tree_map_only(Tensor, lambda t: None if is_fake_tensor(t) else t, value) + return result + # if isinstance(value, Mapping): + # return cast(Any, value.__class__)({k: real_tensors(v) for k, v in value.items()}) + # if isinstance(value, Sequence): + # return cast(Any, value.__class__)([real_tensors(v) for v in value]) + # if is_fake_tensor(value): + # return None + # return value + + +@dynamo_skip +def is_fake_tensor(t: Any) -> bool: + return isinstance(t, (MyFakeTensor, FakeTensor)) or (isinstance(t, Tensor) and t.is_meta) + + +@dynamo_skip +def has_fake_tensor(v: Any) -> bool: + result = pytree.tree_any(is_fake_tensor, v) + return result + + +@dynamo_skip +def is_real_tensor(t: Any) -> bool: + return isinstance(t, Tensor) and not t.is_meta and not isinstance(t, FakeTensor) + + +@dynamo_skip +def get_parent_module_name(module_name: str): + if "." not in module_name: + return "" + else: + return module_name.rsplit(".", 1)[0] + + +@dynamo_skip +def get_parent_module_names(module_name: str): + parent_module_names = set[str]() + + while len(module_name) > 0: + module_name = get_parent_module_name(module_name) + parent_module_names.add(module_name) + + return parent_module_names + + +def distributed_isend_obj( + obj: Any, + dst: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, +) -> list[Optional[torch.distributed.Work]]: + obj_tensor, obj_size_tensor = torch.distributed.distributed_c10d._object_to_tensor( + obj, device="cpu", **_get_group_kwarg_if_necessary() + ) + works: list[Optional[torch.distributed.Work]] = [ + torch.distributed.isend(obj_size_tensor, dst, group), + torch.distributed.isend(obj_tensor, dst, group), + ] + # p2p_ops = [ + # torch.distributed.P2POp(torch.distributed.isend, obj_size_tensor, dst, group), + # torch.distributed.P2POp(torch.distributed.isend, obj_tensor, dst, group), + # ] + + # works = torch.distributed.batch_isend_irecv(p2p_ops) + + return works + + +def distributed_send_obj( + obj: Any, + dst: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, +): + works = distributed_isend_obj(obj=obj, dst=dst, group=group) + for work in works: + if work is not None: + work.wait() + + +def distributed_recv_obj( + src: Optional[int] = None, + group: Optional[torch.distributed.ProcessGroup] = None, +) -> Any: + obj_size_tensor = torch.LongTensor(1, device="cpu") + torch.distributed.recv(obj_size_tensor, src=src, group=group) + obj_size = int(obj_size_tensor.item()) + + obj_tensor = torch.ByteTensor(obj_size, device="cpu") + torch.distributed.recv(obj_tensor, src=src, group=group) + + obj = torch.distributed.distributed_c10d._tensor_to_object( + obj_tensor, obj_size, **_get_group_kwarg_if_necessary() + ) + + return obj + + +def _get_group_kwarg_if_necessary() -> dict: + """For newer versions of torch""" + arg_names = inspect.signature( + torch.distributed.distributed_c10d._object_to_tensor + ).parameters.keys() + return dict(group=None) if "group" in arg_names else dict() diff --git a/modelopt/torch/_compress/tools/__init__.py b/modelopt/torch/_compress/tools/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/modelopt/torch/_compress/tools/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/tools/kd_model.py b/modelopt/torch/_compress/tools/kd_model.py new file mode 100644 index 0000000000..437eb51ca2 --- /dev/null +++ b/modelopt/torch/_compress/tools/kd_model.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Knowledge distillation loss functions. + +Provides normalized_mse_loss and cosine_embedding_loss_batched for comparing +model outputs. Used by validation.py. +""" +# mypy: ignore-errors + +from abc import ABCMeta, abstractmethod +from typing import List, Callable, Literal, Tuple, Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +def normalized_mse_loss( + input: Tensor, + target: Tensor, + reduction: Literal["none", "mean", "sum"] = "mean", + epsilon: float = 1e-6, +) -> Tensor: + loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction=reduction + ) + return loss + + +def cosine_embedding_loss_batched(input: Tensor, target: Tensor) -> Tensor: + # inputs are of shape (B,T,H) + batch_size = input.size(0) + input = input.view(batch_size, -1) + target = target.view(batch_size, -1) + target_tensor = input.new(input.size(0)).fill_(1) + loss = F.cosine_embedding_loss( + input1=input, input2=target, target=target_tensor, reduction="none" + ) + return loss diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index a27cd50771..549ee9a88c 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -29,7 +29,7 @@ import torch.distributed import torch.nn as nn from huggingface_hub import split_torch_state_dict_into_shards -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch._compress.tools.logger import mprint from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file @@ -41,6 +41,7 @@ from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, + DeciLMForCausalLM, rope_type_to_class, ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index 37a49ed236..0e745a0646 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -32,10 +32,13 @@ PreTrainedModel, PreTrainedTokenizerBase, ) -from utils.data.dataloaders import create_validation_dataloader -from utils.parsing import simple_parse_args_string -from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline -from utils.validation import calculate_losses +from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader +from modelopt.torch._compress.utils.parsing import simple_parse_args_string +from modelopt.torch._compress.utils.validate_runtime_pipeline import ( + HiddenStatesAndLMHead, + calculate_losses_pipeline, +) +from modelopt.torch._compress.utils.validation import calculate_losses from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( register_activation_hooks, @@ -162,7 +165,7 @@ def validate_model( ) # Create checkpoint manager with hooks - from utils.checkpoint_manager import ScoringCheckpointManager + from modelopt.torch._compress.utils.checkpoint_manager import ScoringCheckpointManager mprint( f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}" diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/_compress/utils/checkpoint_manager.py new file mode 100644 index 0000000000..318586ba44 --- /dev/null +++ b/modelopt/torch/_compress/utils/checkpoint_manager.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Checkpoint manager for activation hook scoring with periodic saves and resume support. +""" + +import json +import time +from pathlib import Path +from typing import Dict, Any, Optional +from modelopt.torch._compress.tools.logger import mprint, aprint + + +class ScoringCheckpointManager: + """Manages checkpointing for activation hook scoring with periodic saves.""" + + def __init__( + self, checkpoint_dir: str, runtime, activation_hooks=None, checkpoint_interval: int = 100 + ): + """ + Initialize checkpoint manager. + + Args: + checkpoint_dir: Directory to save checkpoints + runtime: Runtime object for distributed processing + activation_hooks: Dictionary of activation hooks to manage + checkpoint_interval: Save checkpoint every N batches + """ + self.checkpoint_dir = Path(checkpoint_dir) + self.runtime = runtime + self.activation_hooks = activation_hooks + self.checkpoint_interval = checkpoint_interval + self.rank = runtime.global_rank if runtime is not None else 0 + self.is_main_process = runtime is None or runtime.is_main_process + + # Debug: Log checkpoint manager initialization + hook_count = len(activation_hooks) if activation_hooks else 0 + aprint( + f"[Rank {self.rank}] Checkpoint manager initialized: {hook_count} hooks, dir: {checkpoint_dir}" + ) + + # Checkpoint files + self.progress_file = self.checkpoint_dir / "scoring_progress.json" + self.hook_states_file = self.checkpoint_dir / f"hook_states_rank_{self.rank}.pth" + + # Progress tracking + self.current_batch = 0 + self.total_batches = 0 + self.start_time = time.time() + + # Ensure directory exists + if self.is_main_process: + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + def load_checkpoint(self) -> Optional[Dict[str, Any]]: + """ + Load existing checkpoint if available, including hook states. + + Returns: + Dict with checkpoint info or None if no checkpoint exists + """ + aprint(f"[Rank {self.rank}] Looking for checkpoint at: {self.progress_file}") + if not self.progress_file.exists(): + aprint(f"[Rank {self.rank}] No checkpoint file found at {self.progress_file}") + return None + + try: + with open(self.progress_file, "r") as f: + checkpoint_data = json.load(f) + + # Validate checkpoint + if "current_batch" in checkpoint_data and "total_batches" in checkpoint_data: + self.current_batch = checkpoint_data["current_batch"] + self.total_batches = checkpoint_data["total_batches"] + + mprint( + f"Found checkpoint: batch {self.current_batch}/{self.total_batches} ({checkpoint_data.get('progress', 0.0):.1%})" + ) + mprint( + f"Will resume from batch {self.current_batch}, skipping batches 0-{self.current_batch - 1}" + ) + + # Load hook states if hooks are available + if self.activation_hooks is not None: + success = self.load_hook_states(self.activation_hooks) + if success: + aprint( + f"[Rank {self.rank}] Successfully loaded hook states from checkpoint" + ) + else: + aprint(f"[Rank {self.rank}] Failed to load hook states - starting fresh") + + return checkpoint_data + else: + aprint( + f"[Rank {self.rank}] Invalid checkpoint format (missing current_batch/total_batches): {checkpoint_data}" + ) + return None + + except (json.JSONDecodeError, KeyError) as e: + mprint(f"Error loading checkpoint: {e}") + + return None + + def load_hook_states(self, activation_hooks) -> bool: + """ + Load hook states from checkpoint files. + + Args: + activation_hooks: Hook objects to load states into + + Returns: + bool: True if hook states were successfully loaded, False otherwise + """ + import os + + # Each rank loads only its own hook states + current_rank = int(os.environ.get("RANK", 0)) + hook_states_path = self.checkpoint_dir / f"hook_states_rank_{current_rank}.pth" + + if hook_states_path.exists(): + aprint(f"[Rank {current_rank}] Loading hook states from {hook_states_path}") + try: + import torch + + hook_states = torch.load(hook_states_path, map_location="cpu") + + # Load states into corresponding hooks + loaded_count = 0 + for module_name, hook in activation_hooks.items(): + if module_name in hook_states: + hook.load_state(hook_states[module_name]) + loaded_count += 1 + + # Log progress info if available (only for a few hooks to avoid spam) + if loaded_count <= 3: # Only log first few hooks + progress_info = hook.get_progress_info() + if progress_info: + aprint(f"[Rank {current_rank}] {module_name}: {progress_info}") + else: + aprint( + f"[Rank {current_rank}] Warning: No saved state found for hook: {module_name}" + ) + + aprint( + f"[Rank {current_rank}] Successfully loaded states for {loaded_count}/{len(activation_hooks)} hooks" + ) + return True + + except Exception as e: + aprint(f"[Rank {current_rank}] Error loading hook states: {e}") + return False + else: + aprint(f"[Rank {current_rank}] No hook states file found at {hook_states_path}") + return False + + def should_skip_batch(self, batch_idx: int) -> bool: + """Check if we should skip this batch (already processed in previous run).""" + should_skip = batch_idx < self.current_batch + if should_skip and batch_idx % 10 == 0: # Log every 10th skipped batch to avoid spam + mprint(f"Skipping batch {batch_idx} (resume from batch {self.current_batch})") + return should_skip + + def update_progress(self, batch_idx: int, total_batches: int): + """ + Update progress and potentially save checkpoint. + + Args: + batch_idx: Current batch index + total_batches: Total number of batches + """ + self.current_batch = batch_idx + self.total_batches = total_batches + + # Save checkpoint periodically or on completion + should_save = ( + (batch_idx % self.checkpoint_interval == 0) # Periodic save + or (batch_idx == total_batches - 1) # Final batch + ) + + if should_save: + # All ranks save their hook states + if self.activation_hooks is not None: + try: + from modelopt.torch._compress.activation_scoring.activation_hooks.hooks import ( + ActivationsHook, + ) + + saved_path = ActivationsHook.save_hook_states( + self.activation_hooks, self.checkpoint_dir, self.runtime + ) + except Exception as e: + mprint(f"Warning: Failed to save hook states: {e}") + + # Only main process saves progress info + if self.is_main_process: + self.save_checkpoint() + + # Synchronize all ranks after checkpointing + if self.runtime is not None: + self.runtime.wait_for_everyone() + + def save_checkpoint(self): + """ + Save current checkpoint to disk (progress info only). + Hook states are saved separately in update_progress. + """ + try: + # Save progress + progress_data = { + "current_batch": self.current_batch, + "total_batches": self.total_batches, + "progress": self.current_batch / self.total_batches + if self.total_batches > 0 + else 0.0, + "timestamp": time.time(), + "elapsed_time": time.time() - self.start_time, + "rank": self.rank, + } + + # Write progress atomically + temp_file = self.progress_file.with_suffix(".tmp") + with open(temp_file, "w") as f: + json.dump(progress_data, f, indent=2) + temp_file.replace(self.progress_file) + + # Hook states are saved at a higher level to ensure all ranks participate + + if self.current_batch % (self.checkpoint_interval) == 0: + progress_pct = progress_data["progress"] * 100 + elapsed = progress_data["elapsed_time"] + mprint( + f"Checkpoint saved: batch {self.current_batch}/{self.total_batches} ({progress_pct:.1f}%), elapsed: {elapsed:.1f}s" + ) + + except Exception as e: + mprint(f"Error saving checkpoint: {e}") + + def finalize(self): + """Mark scoring as completed.""" + # All ranks save their final hook states + if self.activation_hooks is not None: + try: + from modelopt.torch._compress.activation_scoring.activation_hooks.hooks import ( + ActivationsHook, + ) + + saved_path = ActivationsHook.save_hook_states( + self.activation_hooks, self.checkpoint_dir, self.runtime + ) + mprint(f"Final hook states saved to {saved_path}") + except Exception as e: + mprint(f"Warning: Failed to save final hook states: {e}") + + # Only main process saves progress info + if self.is_main_process: + self.current_batch = self.total_batches + self.save_checkpoint() + mprint(f"Scoring completed and finalized: {self.total_batches} batches processed") + + # Synchronize all ranks after finalization + if self.runtime is not None: + self.runtime.wait_for_everyone() diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/_compress/utils/data/dataloaders.py new file mode 100644 index 0000000000..4c4fce0606 --- /dev/null +++ b/modelopt/torch/_compress/utils/data/dataloaders.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DataLoader utilities for language model training and validation. +""" + +import os +from collections.abc import Callable, Mapping, Sequence +from functools import partial +from typing import Protocol, TypeVar + +import datasets +import torch +import torch.distributed +from accelerate import Accelerator +from modelopt.torch._compress.tools.logger import mprint +from torch.utils.data import DataLoader, Dataset, IterableDataset +from torch.utils.data._utils.collate import collate, default_collate_fn_map +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase +from modelopt.torch._compress.utils.data.dataset import ConstantLengthDataset + + +def collate_none_fn( + batch, *, collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None +): + return None + + +collate_fn_map_with_none_support = {**default_collate_fn_map, type(None): collate_none_fn} +collate_fn_with_none_support = partial(collate, collate_fn_map=collate_fn_map_with_none_support) + + +class LoadDatasetFn(Protocol): + def __call__( + self, dataset_path: str, content_field: str, keep_in_memory: bool = False + ) -> Mapping[str, Dataset]: ... + + +def load_from_disk_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + return datasets.load_from_disk(dataset_path, keep_in_memory=keep_in_memory) + + +def load_streaming_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + dataset = datasets.load_dataset( + dataset_path, + streaming=True, + features=datasets.Features( + { + content_field: datasets.Value(dtype="string"), + } + ), + keep_in_memory=keep_in_memory, + ) + + return dataset + + +def create_train_dataloader( + accelerator: Accelerator, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name="train", + keep_in_memory: bool = False, + shuffle_train_data_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, +): + mprint(f"\ncreate_train_dataloader on rank {accelerator.process_index}") + if isinstance(dataset, str): + dataset = load_dataset_fn(dataset, content_field, keep_in_memory) + + train_data = dataset[dataset_name] + if shuffle_train_data_seed is not None: + train_data = train_data.shuffle(seed=shuffle_train_data_seed) + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + infinite=True, + seq_length=block_size * micro_batch_size if varlen else block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + # return_cu_seqlens=varlen, + # seqlen_cap=block_size if varlen else None + ) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1 if varlen else micro_batch_size, + pin_memory=True, + collate_fn=collate_fn_with_none_support, + num_workers=os.cpu_count() // 2 // 8, + ) + + return train_dataloader + + +def create_validation_dataloader( + accelerator: Accelerator | None, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + eval_samples: int | None = None, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "__auto__", + keep_in_memory: bool = False, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, + shuffle_seed: int | None = None, +): + if accelerator is None: + accelerator = Printer() + + if accelerator.is_main_process: + if isinstance(dataset, str): + dataset = load_dataset_fn(dataset, content_field, keep_in_memory) + + if isinstance(dataset, datasets.Dataset | torch.utils.data.Dataset): + valid_data = dataset + mprint( + "#### Path to specific dataset was given (not DatasetDict), taking it as-is ####" + ) + else: + assert isinstance(dataset, datasets.DatasetDict) + if dataset_name == "__auto__": + val_split_options = [] + for val_key_prefix in ("val", "test"): + if len(val_split_options) == 0: + val_split_options = [ + split + for split in dataset # DatasetDict is dict-like and supports direct iteration + if split.lower().startswith(val_key_prefix) + ] + assert len(val_split_options) == 1, ( + f"Expected exactly one validation split, got {val_split_options=} ({dataset.keys()=})" + ) + val_split = val_split_options[0] + mprint(f"Inferred validation split automatically: '{val_split}'") + else: + val_split = dataset_name + mprint(f"Validation split explicitly chosen: '{val_split}'") + valid_data = dataset[val_split] + + if shuffle_seed is not None: + mprint(f"Shuffling with {shuffle_seed=}") + valid_data = valid_data.shuffle(seed=shuffle_seed) + + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + infinite=False, + seq_length=block_size * micro_batch_size if varlen else block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + # return_cu_seqlens=varlen, + # seqlen_cap=block_size if varlen else None + ) + if varlen and eval_samples is not None: + eval_samples = eval_samples // micro_batch_size + val_offloaded_dataset = realize_dataset_in_memory(valid_dataset, eval_samples) + + valid_data_len = len(val_offloaded_dataset) + mprint(f"num validation examples = {valid_data_len}") + else: + val_offloaded_dataset = None + + if not isinstance(accelerator, Printer): + obj_list = [val_offloaded_dataset] + torch.distributed.broadcast_object_list(obj_list) + val_offloaded_dataset = obj_list[0] + + # let accelerate prepare to handle distributed sampling + val_dataloader = DataLoader( + val_offloaded_dataset, + batch_size=1 if varlen else micro_batch_size, + pin_memory=True, + collate_fn=collate_fn_with_none_support, + ) + + return val_dataloader + + +def realize_dataset_in_memory(dataset: IterableDataset, eval_samples: int | None) -> list[dict]: + tqdm_desc = f"realize_dataset_in_memory({eval_samples=})" + if eval_samples is None: + offloaded_dataset = list(tqdm(dataset, desc=tqdm_desc)) + else: + val_iter = iter(dataset) + offloaded_dataset = [next(val_iter) for _ in tqdm(range(eval_samples), desc=tqdm_desc)] + return offloaded_dataset + + +def create_dataloaders( + accelerator: Accelerator, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset_path: str, + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + val_micro_batch_size: int | None = None, + eval_samples: int | None = None, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + train_dataset_name: str = "train", + val_dataset_name: str = "__auto__", + disable_validation: bool = False, + keep_in_memory: bool = False, + shuffle_train_data_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, +): + if val_micro_batch_size is None: + val_micro_batch_size = micro_batch_size + + dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory=keep_in_memory) + + train_dataloader = create_train_dataloader( + accelerator, + seed, + tokenizer, + block_size, + dataset, + content_field, + fim_rate, + fim_spm_rate, + micro_batch_size, + load_dataset_fn, + train_dataset_name, + shuffle_train_data_seed=shuffle_train_data_seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + varlen=varlen, + ) + + if not disable_validation: + val_dataloader = create_validation_dataloader( + accelerator, + seed, + tokenizer, + block_size, + dataset, + content_field, + fim_rate, + fim_spm_rate, + val_micro_batch_size, + eval_samples, + load_dataset_fn, + val_dataset_name, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + varlen=varlen, + ) + else: + val_dataloader = None + + return train_dataloader, val_dataloader + + +TensorT = TypeVar("TensorT", bound=torch.Tensor) + + +@torch.no_grad() +def create_padded_tensor( + tensor: TensorT, desired_shape: Sequence[int], padding_value: float = 0 +) -> TensorT: + if tensor.shape == torch.Size(desired_shape): + return tensor + + padded_tensor = torch.full( + desired_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device + ) + indices = torch.where(torch.ones_like(tensor, dtype=torch.bool)) + padded_tensor[indices] = tensor.view(-1) + return padded_tensor + + +class Printer: + is_main_process = True + process_index = None + + @staticmethod + def print(*args, **kwargs) -> None: + print(*args, **kwargs) diff --git a/modelopt/torch/_compress/utils/data/dataset.py b/modelopt/torch/_compress/utils/data/dataset.py new file mode 100644 index 0000000000..2c7fcef09a --- /dev/null +++ b/modelopt/torch/_compress/utils/data/dataset.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import functools +from typing import Optional +from typing import Sequence + +import numpy as np +import torch +from torch.utils.data import IterableDataset + +from modelopt.torch._compress.tools.logger import aprint, mprint + +FIM_TOKEN_START = "", "middle>", "suffix>", "pad>"] +CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""] + + +class ConstantLengthDataset(IterableDataset): + """ + Iterable dataset that returns constant length chunks of tokens from stream of text files. + Args: + tokenizer (Tokenizer): The processor used for proccessing the data. + dataset (dataset.Dataset): Dataset with text files. + infinite (bool): If True the iterator is reset after dataset reaches end else stops. + seq_length (int): Length of token sequences to return. + num_of_sequences (int): Number of token sequences to keep in buffer. + chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer. + fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM. + fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM. + seed (int): Seed for random number generator. + label_shift (bool): Whether to shift labels by 1 or not. + """ + + def __init__( + self, + tokenizer, + dataset, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + content_field="content", + fim_rate=0.5, + fim_spm_rate=0.5, + seed=0, + label_shift=True, + max_sample_length=200_000, + tokens_field="token_ids", + source_datasets_to_discard: Optional[Sequence[str]] = tuple(), + bos_rate: float = 1.0, + return_cu_seqlens: bool = False, + seqlen_cap: Optional[int] = None, + ): + self.tokenizer = tokenizer + self.concat_token_id = tokenizer.eos_token_id + # self.concat_token_id = tokenizer.eos_id # for lit-lamma tokenizer + self.dataset = dataset + self.is_dataset_already_tokenized = tokens_field in self.dataset.column_names + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + if not self.is_dataset_already_tokenized: + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.max_sample_length = max_sample_length + else: + self.max_buffer_size = seq_length * num_of_sequences + # self.max_sample_length = int(max_sample_length / chars_per_token) + self.max_sample_length = max_sample_length # we don't know the exact chars_per_token + self.content_field = content_field + self.tokens_field = tokens_field + self.fim_rate = fim_rate + self.fim_spm_rate = fim_spm_rate + self.seed = seed + self.max_sample_length = max_sample_length + + self.fim_token_ids = get_fim_token_ids(self.tokenizer) + if None in self.fim_token_ids.values() and self.fim_rate > 0: + self.fim_rate = 0 + self.label_shift = label_shift + self.bos_rate = bos_rate + self.source_datasets_to_discard = ( + source_datasets_to_discard if source_datasets_to_discard is not None else tuple() + ) + self.return_cu_seqlens = return_cu_seqlens + self.seqlen_cap = seqlen_cap + self.np_rng = np.random.RandomState(seed=self.seed) + + def __iter__(self) -> dict[str, torch.Tensor]: + iterator = iter(self.dataset) + more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + sample = next(iterator) + if ( + len(self.source_datasets_to_discard) > 0 + and sample["dataset_name"] in self.source_datasets_to_discard + ): + continue + if not self.is_dataset_already_tokenized: + sample = sample[self.content_field] + if ( + isinstance(sample, list) + and isinstance(sample[0], dict) + and {"content", "role"}.issubset(sample[0]) + ): + if len(sample) > 1: + sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + else: + sample = sample[0]["content"] + else: + sample = sample[self.tokens_field] + sample = sample[: self.max_sample_length] + buffer.append(sample) + buffer_len += len(sample) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + else: + more_examples = False + break + + if not self.is_dataset_already_tokenized: + tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] + else: + tokenized_inputs = buffer + + all_token_ids = [] + + for tokenized_input in tokenized_inputs: + if ( + self.bos_rate < 1.0 + and not self.np_rng.binomial(1, self.bos_rate) + and self.tokenizer.bos_token_id is not None + and tokenized_input[0] == self.tokenizer.bos_token_id + ): + tokenized_input = tokenized_input[1:] + # optionally do FIM permutations + if self.fim_rate > 0: + tokenized_input, np_rng = permute( + sample=tokenized_input, + np_rng=self.np_rng, + fim_token_ids=self.fim_token_ids, + fim_rate=self.fim_rate, + fim_spm_rate=self.fim_spm_rate, + truncate_or_pad=False, + ) + + all_token_ids.extend(tokenized_input + [self.concat_token_id]) + + examples = [] + # cuts code snippets in the middle to yield constant length instances + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + labels = all_token_ids[ + i + int(self.label_shift) : i + int(self.label_shift) + self.seq_length + ] + # ignores last short example in the buffer + if len(labels) == self.seq_length: + examples.append((input_ids, labels)) + + shuffling_indices = self.np_rng.permutation(len(examples)) + examples = [examples[i] for i in shuffling_indices] + + for input_ids, labels in examples: + self.current_size += 1 + input_ids = torch.LongTensor(input_ids) + if self.return_cu_seqlens: + cu_seqlens = self.prepare_cu_seqlens(input_ids) + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + "cu_seqlens": cu_seqlens, + } + else: + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + } + + def prepare_cu_seqlens(self, input_ids): + if not self.return_cu_seqlens: + return None + # seqlens is of shape (num_seqs+1,) and with the property that + # the i-th sequnce is input_ids[seqlens[i-1]:seqlens[i]] + cu_seqlens = (input_ids == self.concat_token_id).nonzero().squeeze(-1).int() + 1 + cu_seqlens = torch.cat( + ( + torch.IntTensor([0]), + cu_seqlens, + torch.IntTensor([len(input_ids)]), + ) + ) + if self.seqlen_cap is not None: + i = 1 + while i < len(cu_seqlens): + curr_seqlen = cu_seqlens[i] - cu_seqlens[i - 1] + if curr_seqlen > self.seqlen_cap: + cu_seqlens = torch.cat( + (cu_seqlens[:i], cu_seqlens[[i - 1]] + self.seqlen_cap, cu_seqlens[i:]) + ) + i += 1 + if cu_seqlens[-1] == cu_seqlens[-2]: + cu_seqlens = cu_seqlens[:-1] + return cu_seqlens + + +## Adapted from https://github.com/NVIDIA/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py +def permute( + sample, + np_rng, + fim_token_ids, + fim_rate=0.5, + fim_spm_rate=0.5, + truncate_or_pad=False, +): + """ + Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes: + PSM and SPM (with a probability of fim_spm_rate). + """ + + if np_rng.binomial(1, fim_rate): + boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2)) + boundaries.sort() + + prefix = np.array(sample[: boundaries[0]], dtype=np.int64) + middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64) + suffix = np.array(sample[boundaries[1] :], dtype=np.int64) + + if truncate_or_pad: + raise NotImplementedError + + if "" in fim_token_ids: # use codegen FIM pattern + assert fim_spm_rate == 0 + new_sample = np.concatenate( + [ + prefix, + [fim_token_ids[""]], + suffix, + [fim_token_ids["<|endoftext|>"]], + [fim_token_ids[""]], + [fim_token_ids[""]], + middle, + ] + ) + elif np_rng.binomial(1, fim_spm_rate): + # SPM (variant 2 from FIM paper) + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"], fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + prefix, + middle, + ] + ) + else: + # PSM + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"]], + prefix, + [fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + middle, + ] + ) + else: + # don't do FIM preproc + new_sample = sample + + return list(new_sample), np_rng + + +# this is expensive so we cache it +@functools.lru_cache(maxsize=None) +def get_fim_token_ids(tokenizer): + # ugly fix for Salesforce/codegen25-7b-multi tokenizer + if hasattr(tokenizer, "encoder"): + search_vocab = tokenizer.encoder._special_tokens + fim_token_ids = {tok: search_vocab.get(tok, None) for tok in CODEGEN_FIM_TOKENS} + else: + search_vocab = tokenizer.vocab + if (FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + FIM_TOKEN_END_LIST[0]) in search_vocab: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + else: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_SANTA + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + fim_token_ids = { + "suffix_tok_id": suffix_tok_id, + "prefix_tok_id": prefix_tok_id, + "middle_tok_id": middle_tok_id, + "pad_tok_id": pad_tok_id, + } + return fim_token_ids diff --git a/modelopt/torch/_compress/utils/parsing.py b/modelopt/torch/_compress/utils/parsing.py new file mode 100644 index 0000000000..97f698ba91 --- /dev/null +++ b/modelopt/torch/_compress/utils/parsing.py @@ -0,0 +1,455 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parsing and formatting utilities for configuration handling in model compression. + +This module provides utilities for: +- Parsing command-line arguments and configuration strings +- Formatting and displaying model configurations (block configs, attention, FFN) +- Formatting loss metrics for logging and visualization +""" +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any + +import torch +from omegaconf import DictConfig + + +def handle_arg_string(arg): + if arg.lower() == "true": + return True + elif arg.lower() == "false": + return False + elif arg.isnumeric(): + return int(arg) + try: + return float(arg) + except ValueError: + return arg + + +def simple_parse_args_string(args_string): + """ + Parses something like + args1=val1,arg2=val2 + Into a dictionary + """ + if args_string is None: + return {} + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = {k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]} + return args_dict + + +def parse_json(s: str | None) -> Any: + if s is None: + return None + return json.loads(s) + + +def parse_path(s: str | None) -> Path | None: + if s is None or s == "": + return None + return Path(s) + + +def parse_dtype(dtype_name: str) -> torch.dtype: + dtype = { + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, + "fp32": torch.float32, + "float32": torch.float32, + "fp16": torch.float16, + "float16": torch.float16, + }[dtype_name] + return dtype + + +def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any: + """ + If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"] + """ + value = dictionary + for key in nested_key.split("."): + value = value[key] + return value + + +def format_block_configs(config) -> str: + """ + Formats block_configs from a model configuration into a beautiful, readable string. + + Each line represents a layer with attention and FFN configuration. + + Args: + config: PretrainedConfig object containing block_configs + + Returns: + Formatted string with layer configurations + + Example output: + ╭─────────────────────── Model Architecture ────────────────────────╮ + │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │ + │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │ + │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │ + ╰────────────────────────────────────────────────────────────────────╯ + """ + if not hasattr(config, "block_configs") or not config.block_configs: + return "❌ No block configs found" + + lines = [] + + # Header + header = "╭─────────────────────────────────────── Model Architecture ────────────────────────────────────────╮" + lines.append(header) + + # Format each layer + for i, block in enumerate(config.block_configs, 1): + attention_info = _format_attention_config(block.attention) + ffn_info = _format_ffn_config(block.ffn) + + # Create formatted line with proper padding + layer_str = f"Layer {i:2d}" + attention_str = f"Attention: {attention_info}" + ffn_str = f"FFN: {ffn_info}" + + line = f"│ {layer_str:8s} │ {attention_str:30s} │ {ffn_str:18s} │" + lines.append(line) + + # Footer + footer = "╰────────────────────────────────────────────────────────────────────────────────────────────────────╯" + lines.append(footer) + + return "\n".join(lines) + + +def _format_attention_config(attention_config) -> str: + """Format attention configuration for display with visual indicators.""" + if not attention_config: + return "default" + + if attention_config.no_op: + return "❌ no_op" + + n_heads = attention_config.n_heads_in_group + if n_heads is not None: + return f"{n_heads} heads in group" + + if attention_config.replace_with_linear: + return "linear replacement" + + # Check for other attention types + if attention_config.mamba: + return "🐍 mamba" + if attention_config.llama4: + return "🦙 llama4" + + window_length = attention_config.window_length + if window_length is not None: + return f"windowed ({window_length})" + + if attention_config.sparsify: + return "sparse" + + return "default" + + +def _format_ffn_config(ffn_config) -> str: + """Format FFN configuration for display with visual indicators.""" + if not ffn_config: + return "default" + + if ffn_config.no_op: + return "❌ no_op" + + if ffn_config.replace_with_linear: + return "linear" + + ffn_intermediate = ffn_config.intermediate_size + if ffn_intermediate is not None: + return f"ffn_intermediate = {ffn_intermediate}" + + # Check for MoE configuration + moe_config = ffn_config.moe + if moe_config: + return "MoE" + + if ffn_config.sparsify: + return "sparse" + + return "default" + + +def format_global_config(config: DictConfig, title: str = "Global Configuration") -> str: + """ + Pretty prints a global DictConfig with nice formatting and visual indicators. + + Args: + config: DictConfig object to format + title: Title to display at the top of the formatted output + + Returns: + Formatted string with configuration details + + Example output: + ╭─────────────────── Global Configuration ────────────────────╮ + │ Training │ + │ • learning_rate: 1e-4 │ + │ • batch_size: 32 │ + │ • epochs: 100 │ + │ Model │ + │ • hidden_dim: 512 │ + │ • num_layers: 6 │ + │ Data │ + │ • dataset_path: /path/to/data │ + │ • block_size: 2048 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not config: + return "❌ No configuration found" + + lines = [] + + # Calculate box width based on title + box_width = max(60, len(title) + 10) + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"\n╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + lines.extend([header, title_line]) + + def _format_value(value: Any, indent: int = 0) -> str: + """Format a value with appropriate type indicators.""" + prefix = " " * indent + + if isinstance(value, (bool, int, float)): + return f"{prefix} {value}" + elif isinstance(value, str): + # Show truncated long strings + if len(value) > 50: + return f"{prefix} {value[:47]}..." + return f"{prefix} {value}" + elif isinstance(value, (list, tuple)): + if not value: + return f"{prefix} []" + elif len(value) <= 3: + return f"{prefix} {list(value)}" + else: + return f"{prefix} [{len(value)} items]" + elif value is None: + return f"{prefix} None" + else: + return f"{prefix} {value!s}" + + def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0): + """Recursively add configuration sections.""" + if section_name: + indent_str = " " * indent + section_line = f"│ {indent_str}{section_name}" + # Pad to box width + padding_needed = box_width - len(section_line) - 1 + section_line += " " * padding_needed + "│" + lines.append(section_line) + + for key, value in cfg.items(): + if isinstance(value, DictConfig): + # Nested configuration section + _add_config_section(value, f"{key}", indent + 1) + else: + # Regular key-value pair + indent_str = " " * (indent + 1) + value_str = _format_value(value).replace(" " * 0, "").strip() + line = f"│ {indent_str} {key}: {value_str}" + # Pad to box width + if len(line) >= box_width - 1: + # Truncate long lines + line = line[: box_width - 4] + "..." + padding_needed = box_width - len(line) - 1 + line += " " * padding_needed + "│" + lines.append(line) + + # Add configuration sections + _add_config_section(config) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) + + +def format_stitched_losses( + losses_dict: dict[str, float], + best_steps_dict: dict[str, int] | None = None, + best_values_dict: dict[str, float] | None = None, + step_number: int | None = None, + title: str = "Stitched Module Losses", +) -> str: + """ + Pretty prints stitched module losses with comprehensive tracking and visual indicators. + + Args: + losses_dict: Dictionary with block names as keys and current loss values as floats + best_steps_dict: Optional dictionary with block names as keys and best step numbers as values + best_values_dict: Optional dictionary with block names as keys and best loss values as floats + step_number: Optional current step number to include in summary + title: Title to display at the top of the formatted output + + Returns: + Formatted string with loss values in a comprehensive table format + + Example output: + ╭─────────────────── Stitched Module Losses ──────────────────╮ + │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │ + │───────┼────────────┼───────────┼────────────┼──────────────────│ + │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │ + │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │ + │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not losses_dict: + return "❌ No losses found" + + lines = [] + + # Calculate statistics + loss_values = list(losses_dict.values()) + max_loss = max(loss_values) + min_loss = min(loss_values) + avg_loss = sum(loss_values) / len(loss_values) + + # Calculate box width for new layout (removed Bar column) + box_width = 74 + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + separator = ( + f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ " + f"{'Best Value':<12} │ {'Change from avg':<18} │" + ) + divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│" + + lines.extend([header, title_line, separator, divider]) + + # Format each loss + for block_name, loss_value in losses_dict.items(): + # Format current loss value + loss_str = f"{loss_value:.2e}" + + # Format best step + if best_steps_dict and block_name in best_steps_dict: + best_step_str = f"Step {best_steps_dict[block_name]}" + else: + best_step_str = " --" + + # Format best value + if best_values_dict and block_name in best_values_dict: + best_value = best_values_dict[block_name] + best_value_str = f"{best_value:.2e}" + else: + best_value = loss_value # Assume current is best if no history + best_value_str = f"{best_value:.2e}" + + # Calculate change from average + change_from_avg = loss_value - avg_loss + if abs(change_from_avg) > 1e-8: # Only show if meaningful + change_str = f"{abs(change_from_avg):.1e}" + if change_from_avg > 0: + # Current is above average (worse for loss) + change_display = f"↑ +{change_str}" + else: + # Current is below average (better for loss) + change_display = f"↓ -{change_str}" + else: + # At average value + change_display = "↔ 0.0e+00" + + # Format the line + block_display = block_name.replace("block_", "").zfill(2) + + line = ( + f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ " + f"{best_value_str:<12} │ {change_display:<18} │" + ) + lines.append(line) + + # Add summary statistics + lines.append(divider) + + # Build summary string with optional step number + summary_parts = [] + if step_number is not None: + summary_parts.append(f"Step {step_number}") + summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"]) + + summary_text = ", ".join(summary_parts) + summary = f"│ Summary: {summary_text}" + + # Pad summary to box width + padding_needed = box_width - len(summary) - 1 + summary += " " * padding_needed + "│" + lines.append(summary) + + # Add best step summary if we have best step data + if best_steps_dict and best_values_dict: + # Find the most common best step (modal step) + step_counts = {} + for step in best_steps_dict.values(): + step_counts[step] = step_counts.get(step, 0) + 1 + + if step_counts: + modal_best_step = max(step_counts, key=step_counts.get) + + # Get values at the modal best step for blocks that have it as their best + best_step_values = [] + for block_name, best_step in best_steps_dict.items(): + if best_step == modal_best_step and block_name in best_values_dict: + best_step_values.append(best_values_dict[block_name]) + + if best_step_values: + best_step_avg = sum(best_step_values) / len(best_step_values) + best_step_max = max(best_step_values) + best_step_min = min(best_step_values) + + best_step_summary_text = ( + f"Best: Step {modal_best_step}, Avg={best_step_avg:.2e}, " + f"Max={best_step_max:.2e}, Min={best_step_min:.2e}" + ) + best_step_summary = f"│ {best_step_summary_text}" + + # Pad best step summary to box width + padding_needed = box_width - len(best_step_summary) - 1 + best_step_summary += " " * padding_needed + "│" + lines.append(best_step_summary) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py new file mode 100644 index 0000000000..08e1221a72 --- /dev/null +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -0,0 +1,390 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. + +Coordinates forward passes and loss computation through model shards distributed across GPUs +using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. + +Used by validate_model.py during activation scoring for sharded models. +""" +# mypy: ignore-errors + +from statistics import mean + +import numpy as np +import torch +import torch.distributed +import wandb +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.checkpoint_utils import init_module_with_state_dict +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMForCausalLM, + LMHead, +) +from modelopt.torch._compress.tools.runtime import IRuntime +from modelopt.torch._compress.sewing_kit import ( + ExternalTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, +) +from modelopt.torch._compress.sewing_kit.core import InputReducer +from modelopt.torch._compress.sewing_kit.utils import ( + distributed_recv_obj, + distributed_send_obj, + fake_tensor, +) +from torch.utils.data import DataLoader +from tqdm import tqdm +from modelopt.torch._compress.tools.sharded_checkpoint_utils import DummyBlock +from modelopt.torch._compress.utils.validation import _organize_outputs, calculate_batch_outputs + + +@torch.no_grad() +def validate_pipeline_inner( + runtime: IRuntime, + stitched_model: StitchedModule, + val_dataloader: DataLoader | None, +) -> float: + if runtime.is_main_process: + assert val_dataloader.batch_size is not None + model_device = next(stitched_model.parameters()).device + + with runtime.autocast(): + stitched_model.eval() + + all_logits: list[torch.Tensor] = [] + all_targets: list[torch.Tensor] = [] + losses: list[float] = [] + + if runtime.is_main_process: + input_ids: torch.Tensor + targets: torch.Tensor + + for i_batch, batch in enumerate(tqdm(val_dataloader)): + input_ids, targets = ( + batch["input_ids"].to(model_device), + batch["targets"].to(model_device), + ) + + if i_batch == 0: + num_batches = len(val_dataloader) + seq_len = input_ids.shape[1] + if torch.distributed.is_initialized(): + torch.distributed.broadcast_object_list([(num_batches, seq_len)]) + + all_targets.append(targets.cpu()) + + output = stitched_model({}, {}, input_ids) + logits = output.captured_outputs.get("model_output") + logits = getattr(logits, "logits", logits) + + if logits is not None: + all_logits.append(logits.cpu()) + + del output, logits + + if len(all_targets) > 0: + distributed_send_obj(all_targets, dst=runtime.world_size - 1) + + else: + obj_list: list[tuple] = [None] + torch.distributed.broadcast_object_list(obj_list) + num_batches, seq_len = obj_list[0] + + fake_input_ids = fake_tensor(1, seq_len, dtype=runtime.dtype) + + for i in range(num_batches): + output = stitched_model({}, {}, fake_input_ids) + logits = output.captured_outputs.get("model_output") + logits = getattr(logits, "logits", logits) + if logits is not None: + all_logits.append(logits.cpu()) + del output, logits + + if len(all_targets) == 0 and runtime.global_rank == runtime.world_size - 1: + all_targets = distributed_recv_obj(src=0) + + torch.distributed.barrier() + + if len(all_logits) > 0: + for logits, targets in zip(all_logits, all_targets): + logits = logits.to("cuda") + targets = targets.to("cuda") + logit_losses = torch.nn.functional.cross_entropy( + logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" + ) + + mean_losses = logit_losses.cpu().mean(dim=-1) + losses.extend(mean_losses.tolist()) + + val_loss = mean(losses) + + if not runtime.is_main_process: + distributed_send_obj(val_loss, dst=0) + elif runtime.is_main_process: + val_loss = distributed_recv_obj() + else: + val_loss = float("nan") + + stitched_model.train() + + loss_list = [val_loss] + torch.distributed.broadcast_object_list(loss_list) + val_loss = loss_list[0] + + return val_loss + + +@torch.no_grad() +def validate_pipeline( + runtime: IRuntime, + stitched_model: StitchedModule, + model_config: DeciLMConfig, + val_dataloader: DataLoader, + iter_num: int | None = None, + max_iters: int | None = None, + model_name: str | None = None, + enable_print: bool = True, + enable_wandb_log: bool = False, + # pad_to_batchsize: bool = True, +) -> float: + if enable_print: + mprint("Validating ...") + + val_loss = validate_pipeline_inner( + runtime=runtime, + stitched_model=stitched_model, + val_dataloader=val_dataloader, + ) + + if runtime.is_main_process: + key = "val/loss" if model_name is None else f"val/{model_name}_loss" + if enable_print: + prefix = "" + if iter_num is not None: + prefix += f"iter {iter_num}" + if max_iters is not None: + prefix += f"/{max_iters}" + prefix += " - " + mprint(f"{prefix}{key}: {val_loss:.4f}") + if enable_wandb_log: + wandb.log({key: val_loss}, step=iter_num) + + runtime.wait_for_everyone() + + return val_loss + + +class HiddenStatesAndLMHead(list): + def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): + super().__init__(hidden_states) + self.lm_head_weights = lm_head_weights + + +@torch.no_grad() +def calculate_losses_pipeline( + runtime: IRuntime, + stitched_model: StitchedModule | DeciLMForCausalLM, + dataloader: DataLoader | None, + target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, + return_hidden_states: bool = False, + calculate_full_score_ablations: bool = False, + calc_on_cpu: bool = False, + just_model_forward: bool = False, + checkpoint_manager=None, +) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + """ + Do model forward on each batch and calculate LM loss. + Optionally also calculate kl_div loss and other metrics from given target_hidden_states_per_batch. + Optionally return hidden states per batch. + Does not support data-parallel. + just_model_forward: skip loss calculation, just forward the model. Useful for activation hooks. + + + Returns: + losses: dict = { + "lm_loss": { + "avg": float, + "per_sample": list[float] + } + more metrics if provided with target_hidden_states_per_batch + } + target_hidden_states_per_batch: list[torch.Tensor], returned if return_hidden_states=True + + """ + if isinstance(stitched_model, DeciLMForCausalLM): + stitched_model = perform_pipeline_stitches(stitched_model, runtime) + + params = list(stitched_model.parameters()) + model_device = params[0].device if params else "cpu" + + # Pre-populate outputs with dummy values for skipped batches + start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 + if runtime.is_last_process: + outputs = [{"lm_loss": [0.0]}] * start_batch + else: + outputs = None + + if runtime.is_main_process: + all_input_ids, all_targets = zip( + *[(batch["input_ids"], batch["targets"]) for batch in dataloader] + ) + if runtime.world_size > 1: + distributed_send_obj(all_targets, dst=runtime.world_size - 1) + + if runtime.is_last_process: + if runtime.world_size > 1: + all_targets = distributed_recv_obj(src=0) + + lm_head: LMHead = next( + module + for module_name, module in stitched_model.named_modules() + if "lm_head" in module_name + ) + + if target_hidden_states_per_batch is not None: + lm_head_weights = target_hidden_states_per_batch.lm_head_weights + with torch.device(model_device): + target_lm_head = init_module_with_state_dict( + {"weight": lm_head_weights}, LMHead, *lm_head_weights.shape[::-1], bias=False + ) + + if runtime.is_main_process: + num_batches = len(all_input_ids) + seq_len = all_input_ids[0].shape[1] + if runtime.world_size > 1: + torch.distributed.broadcast_object_list([num_batches, seq_len]) + + # Create progress bar with sliced range starting from checkpoint position + desc = ( + f"[rank {runtime.global_rank}] calculate_losses_pipeline(" + f"{(target_hidden_states_per_batch is None)=}, {return_hidden_states=}, {num_batches=})" + ) + progress_bar = tqdm(range(start_batch, num_batches), desc=desc) + else: + obj_list = [None, None] + if runtime.world_size > 1: + torch.distributed.broadcast_object_list(obj_list) + num_batches, seq_len = obj_list + progress_bar = range(start_batch, num_batches) + + stitched_model.eval() + + with runtime.autocast(): + for i_batch in progress_bar: + if runtime.is_main_process: + input_ids = all_input_ids[i_batch].to(model_device) + else: + input_ids = fake_tensor(1, seq_len, dtype=torch.long) + + output = stitched_model({}, {}, input_ids) + + if runtime.is_last_process: + logits = output.captured_outputs.get("model_output") + logits = getattr(logits, "logits", logits) + hidden_states = output.captured_outputs.get("hidden_states") + targets = all_targets[i_batch].to(model_device) + + target_hidden_states = None + target_logits = None + if target_hidden_states_per_batch is not None: + target_hidden_states = target_hidden_states_per_batch[i_batch] + target_hidden_states = target_hidden_states.to(hidden_states.device) + target_logits = target_lm_head(target_hidden_states) + + if just_model_forward: + batch_outputs = {"lm_loss": [-1.0] * len(targets)} + else: + batch_outputs = calculate_batch_outputs( + hidden_states, + target_hidden_states, + logits, + target_logits, + targets, + return_hidden_states, + calculate_full_score_ablations, + calc_on_cpu, + ) + + outputs.append(batch_outputs) + + # Update checkpoint progress periodically + if checkpoint_manager: + checkpoint_manager.update_progress(i_batch + 1, num_batches) + + losses, hidden_states_per_batch = ( + _organize_outputs(outputs) if outputs is not None else (None, None) + ) + + if hidden_states_per_batch is not None: + hidden_states_per_batch = HiddenStatesAndLMHead( + hidden_states_per_batch, lm_head.weight.cpu() + ) + + runtime.wait_for_everyone() + return losses, hidden_states_per_batch + + +def perform_pipeline_stitches( + model: DeciLMForCausalLM, + runtime: IRuntime, +) -> StitchedModule: + target = ModuleTarget("module", model) + stitcher = Needle() + + is_real_block = np.flatnonzero( + [not isinstance(block, DummyBlock) for block in model.model.layers] + ) + first_block, last_block = is_real_block.min(), is_real_block.max() + + if runtime.global_rank != 0: + # receive activations from previous rank + stitcher.stitch( + RemoteTarget(peer_rank=runtime.global_rank - 1).value( + name="activations", adapter=lambda x: InputArgs(x) + ), + target.input( + name=f"model.layers.{first_block}", + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + + if not runtime.is_last_process: + # send activations to next rank + stitcher.stitch( + target.output(f"model.layers.{last_block}"), + RemoteTarget(peer_rank=runtime.global_rank + 1).value(name="activations"), + ) + else: + # register model output + stitcher.stitch( + target.output(name="lm_head"), + ExternalTarget().output("model_output"), + ) + stitcher.stitch( + target.output(name="model.norm"), + ExternalTarget().output("hidden_states"), + ) + + stitched_module = stitcher.knot(ignore_extra_overrides=True) + return stitched_module diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py new file mode 100644 index 0000000000..63c6642248 --- /dev/null +++ b/modelopt/torch/_compress/utils/validation.py @@ -0,0 +1,826 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model validation and loss calculation utilities for single-GPU and multi-GPU setups. + +Also provides helper functions for loss metrics, KL divergence, JS divergence, +and similarity losses for knowledge distillation. +""" + +# mypy: ignore-errors +import functools +import math +from enum import Enum +from statistics import mean + +import numpy as np +import torch +import torch.distributed +import torch.nn.functional as F +import wandb +from accelerate import Accelerator +from modelopt.torch._compress.tools import kd_model +from torch import nn +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper +from typing_extensions import Self +from modelopt.torch._compress.utils.data.dataloaders import create_padded_tensor + + +@torch.no_grad() +def _validate_single( + accelerator: Accelerator, + model: torch.nn.Module, + rope_cache: torch.Tensor | None, + val_dataloader: DataLoader, + pad_to_batchsize: bool = True, + compute_kl_div: bool = False, + varlen: bool = False, + concat_token_id: int | None = None, +) -> list[float]: + assert val_dataloader.batch_sampler.batch_size is not None + desired_batch_size = val_dataloader.batch_sampler.batch_size + + with accelerator.device, accelerator.autocast(): + model.eval() + + losses: list[float] = [] + + input_ids: torch.LongTensor + targets: torch.LongTensor + is_first_batch = True + for batch in tqdm(val_dataloader, disable=not accelerator.is_main_process): + if is_first_batch: + print( + f"First batch, device {accelerator.device}, input_ids: {batch['input_ids'][:4]}" + ) + is_first_batch = False + input_ids, targets = ( + batch["input_ids"].to(accelerator.device), + batch["targets"].to(accelerator.device), + ) + batch_size = input_ids.size(0) + + if pad_to_batchsize: + input_ids = create_padded_tensor( + input_ids, (desired_batch_size, *input_ids.shape[1:]) + ) + targets = create_padded_tensor(targets, (desired_batch_size, *targets.shape[1:])) + + if rope_cache is not None: + logits = model( + input_ids, rope_cache=rope_cache, varlen=varlen, concat_token_id=concat_token_id + ) + else: + logits = model(input_ids) + + if hasattr(logits, "logits"): # For HF models + logits = logits.logits + + if isinstance(logits, tuple): # For KD + logits, teacher_logits, kd_block_loss, kd_logits_loss = logits + + if compute_kl_div: + # assumes kd_logits_loss has entry for each batch item + batch_losses = kd_logits_loss[:batch_size] + else: + batch_losses = torch.nn.functional.cross_entropy( + logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" + )[:batch_size].mean(dim=-1) + + losses.extend(batch_losses.tolist()) + + model.train() + + return losses + + +@torch.no_grad() +def validate_parallel( + accelerator: Accelerator, + model: torch.nn.Module, + rope_cache: torch.Tensor | None, + val_dataloader: DataLoader, + pad_to_batchsize: bool = True, + compute_kl_div: bool = False, + varlen: bool = False, + concat_token_id: int | None = None, +) -> float: + losses = _validate_single( + accelerator=accelerator, + model=model, + rope_cache=rope_cache, + val_dataloader=val_dataloader, + pad_to_batchsize=pad_to_batchsize, + compute_kl_div=compute_kl_div, + varlen=varlen, + concat_token_id=concat_token_id, + ) + + results = [float("nan")] + if accelerator.is_main_process: + gathered_results = [[float("nan")]] * accelerator.num_processes + torch.distributed.gather_object(losses, gathered_results) + gathered_losses = [l for result in gathered_results for l in result] + results[0] = mean(gathered_losses) + else: + torch.distributed.gather_object(losses) + + torch.distributed.broadcast_object_list(results) + val_loss = results[0] + + return val_loss + + +@torch.no_grad() +def validate( + accelerator: Accelerator, + model: torch.nn.Module, + rope_cache: torch.Tensor | None, + val_dataloader: DataLoader, + iter_num: int | None = None, + max_iters: int | None = None, + model_name: str | None = None, + enable_print: bool = True, + enable_wandb_log: bool = False, + pad_to_batchsize: bool = True, + compute_kl_div: bool = False, + varlen: bool = False, + concat_token_id: int | None = None, +) -> float: + if enable_print: + accelerator.print("Validating ...") + + val_loss = validate_parallel( + accelerator=accelerator, + model=model, + rope_cache=rope_cache, + val_dataloader=val_dataloader, + pad_to_batchsize=pad_to_batchsize, + compute_kl_div=compute_kl_div, + varlen=varlen, + concat_token_id=concat_token_id, + ) + + if accelerator.is_main_process: + key = "val/loss" if model_name is None else f"val/{model_name}_loss" + if enable_print: + prefix = "" + if iter_num is not None: + prefix += f"iter {iter_num}" + if max_iters is not None: + prefix += f"/{max_iters}" + prefix += " - " + accelerator.print(f"{prefix}{key}: {val_loss:.4f}", show_delta=True) + if enable_wandb_log: + wandb.log({key: val_loss}, step=iter_num) + accelerator.wait_for_everyone() + + return val_loss + + +class UnshardedLowMemorySparseTensor: + def __init__(self, x: torch.Tensor): + inds_dtype = self._infer_inds_dtype(x) + x_sparse = x.to_sparse_coo() + self._values = x_sparse.values() + self._indices = x_sparse.indices().to(inds_dtype) + self._size = x_sparse.size() + + @staticmethod + def _infer_inds_dtype(x: torch.Tensor) -> torch.dtype: + max_dim = max(x.shape) + for inds_dtype in [torch.int16, torch.int32, torch.int64]: + if torch.iinfo(inds_dtype).max >= max_dim: + return inds_dtype + + def to_sparse_coo(self) -> torch.Tensor: + return torch.sparse_coo_tensor(values=self._values, indices=self._indices, size=self._size) + + def to_dense(self) -> torch.Tensor: + return self.to_sparse_coo().to_dense() + + def to(self, *args) -> Self: + self._values = self._values.to(*args) + for arg in args: + if isinstance(arg, torch.device) or isinstance(arg, str): + self._indices = self._indices.to(arg) + return self + + +class LowMemorySparseTensor: + _max_sparse_size = torch.iinfo(torch.int32).max + + def __init__(self, x: torch.Tensor): + num_chunks = math.ceil(x.numel() / self._max_sparse_size) + self._chunk_dim = np.argmax(x.shape) + self._chunks = [ + UnshardedLowMemorySparseTensor(chunk) + for chunk in torch.chunk(x, num_chunks, dim=self._chunk_dim) + ] + + def to(self, *args) -> Self: + for chunk in self._chunks: + chunk.to(*args) + return self + + def to_dense(self) -> torch.Tensor: + return torch.concat([chunk.to_dense() for chunk in self._chunks], dim=self._chunk_dim) + + +@torch.no_grad() +def calculate_losses( + model: nn.Module, + dataloader: DataLoader, + target_probs: None = None, + return_probs: bool = False, + checkpoint_manager=None, +) -> tuple[dict[str, dict], None] | tuple[None, None]: + """ + Do model forward on each batch and calculate LM loss. + Works on lit-llama models (single gpu) and huggingface models (can be multi gpu). + Does not support data-parallel. + + ### Anything related to probs and hidden states is not supported currently! ### + calculate_losses() isn't updated according to the major refactor in + calculate_losses_pipeline() regarding hidden states. + + Returns: + outputs = { + "lm_loss": list[float], + "token_accuracy_top_1": list[float], + "token_accuracy_top_5": list[float], + "token_accuracy_top_10": list[float], + } + """ + if (target_probs is not None) or return_probs: + raise NotImplementedError( + "calculate_losses() isn't updated according to the major refactor in " + "calculate_losses_pipeline() regarding hidden states." + ) + + model_device = next(model.parameters()).device + outputs = [] + + try: + num_batches = len(dataloader) + except: + num_batches = None + + # Adjust progress bar for resume + start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 + progress_bar = tqdm( + enumerate(dataloader), + total=num_batches, + desc=f"calculate_losses({(target_probs is None)=}, {return_probs=})", + ) + if start_batch > 0: + progress_bar.update(start_batch) + + for i_batch, batch in progress_bar: + # Skip batch if resuming from checkpoint + if checkpoint_manager and checkpoint_manager.should_skip_batch(i_batch): + continue + + input_ids = batch["input_ids"].to(model_device) + logits = model(input_ids) + if hasattr(logits, "logits"): + logits = logits.logits + # logits = logits.float() + + targets = batch["targets"].to(model_device) + + batch_outputs = calculate_batch_outputs( + hidden_states=None, + target_hidden_states=None, + logits=logits, + target_logits=None, + targets=targets, + return_hidden_states=False, + calculate_full_score_ablations=False, + calc_on_cpu=False, + ) + outputs.append(batch_outputs) + + # Update checkpoint progress periodically + if checkpoint_manager: + checkpoint_manager.update_progress(i_batch + 1, num_batches) + + losses, _ = _organize_outputs(outputs) + return losses, None + + +def calc_entropy(logits: torch.Tensor) -> torch.Tensor: + """ + Returns per-token entropy given a logits tensor of shape [batch_size x seq_len x vocab_size]. + The output will have shape [batch_size x seq_len]. + """ + # Convert logits to log-probabilities + log_probs = F.log_softmax(logits, dim=-1) # shape: [B x T x V] + + # Compute probabilities from log-probabilities + probs = torch.exp(log_probs) # shape: [B x T x V] + + # Entropy calculation: sum over V of (- p * log p) + ent = -torch.sum(probs * log_probs, dim=-1) # shape: [B x T] + + return ent + + +def confidence_max_softmax(logits: torch.Tensor) -> torch.Tensor: + """ + Returns per-token max-softmax confidence given a logits tensor of shape [batch_size x seq_len x vocab_size]. + The output will have shape [batch_size x seq_len]. + """ + # Compute softmax probabilities + probs = F.softmax(logits, dim=-1) # shape: [B x T x V] + + # Take the maximum probability along the vocabulary dimension + max_confidence = torch.max(probs, dim=-1).values # shape: [B x T] + + return max_confidence + + +def calculate_batch_outputs( + hidden_states: torch.Tensor | None, + target_hidden_states: torch.Tensor | None, + logits: torch.Tensor, + target_logits: torch.Tensor | None, + targets: torch.Tensor, + return_hidden_states: bool, + calculate_full_score_ablations: bool, + calc_on_cpu: bool, +) -> dict: + if calc_on_cpu: + if hidden_states is not None: + hidden_states = hidden_states.cpu() + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.cpu() + if logits is not None: + logits = logits.cpu() + if target_logits is not None: + target_logits = target_logits.cpu() + if targets is not None: + targets = targets.cpu() + + batch_outputs = _calculate_ground_truth_based_scores(logits, targets) + + # _DEBUG_calculate_per_token_entropy(batch_outputs, logits) + + if (target_hidden_states is not None) or (target_logits is not None): + batch_outputs.update( + _calculate_teacher_similarity_scores( + hidden_states, + target_hidden_states, + logits, + target_logits, + calculate_full_score_ablations, + ) + ) + + if return_hidden_states: + batch_outputs["hidden_states_per_batch"] = hidden_states.cpu() + + return batch_outputs + + +def _DEBUG_calculate_per_token_entropy(batch_outputs, logits, i_batch): + import os + + # calculate the per token entropy and per token top p + entropy = calc_entropy(logits).cpu() # .view(-1)#.tolist() + msftm = confidence_max_softmax(logits).cpu() # .view(-1)#.tolist() + teacher_dir = ".../meta-llama/Meta-Llama-3.1-70B-Instruct-new_rope/" + file_path = f"{teacher_dir}/validation/per_token_stats_{i_batch}.pth" + os.makedirs(os.path.dirname(file_path), exist_ok=True) + torch.save({"entropy": entropy, "max_softmax": msftm}, file_path) + batch_outputs["entropy"] = entropy + batch_outputs["max_softmax"] = msftm + + +def _organize_outputs( + outputs_per_batch: list[dict], +) -> tuple[dict[str, dict], list[torch.Tensor] | None]: + outputs = _concatenate_batch_outputs(outputs_per_batch) + hidden_states_per_batch = outputs.pop("hidden_states_per_batch", None) + losses = { + loss_name: { + "avg": sum(loss_per_sample) / len(loss_per_sample), + "per_sample": loss_per_sample, + } + for loss_name, loss_per_sample in outputs.items() + } + return losses, hidden_states_per_batch + + +def _concatenate_batch_outputs(outputs_per_batch: list[dict]) -> dict[str, list]: + outputs = {} + for output_name in outputs_per_batch[0]: # Regular dict is directly iterable + item_list = [] + for batch_outputs in outputs_per_batch: + batch_items = batch_outputs[output_name] + if isinstance(batch_items, list | tuple): + item_list.extend(batch_items) + else: + item_list.append(batch_items) + outputs[output_name] = item_list + return outputs + + +def _calculate_per_sample_lm_loss( + logits: torch.Tensor, + targets: torch.Tensor, +) -> list[float]: + per_sample_lm_loss = ( + torch.nn.functional.cross_entropy( + logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" + ) + .mean(dim=-1) + .tolist() + ) + return per_sample_lm_loss + + +def _calculate_ground_truth_based_scores( + logits: torch.Tensor, + targets: torch.Tensor, +) -> dict[str, list[float]]: + scores = {"lm_loss": _calculate_per_sample_lm_loss(logits, targets)} + + for top_k in (1, 5, 10): + top_k_predictions = logits.topk(top_k, dim=-1).indices # [b, t, top_k] + is_target_in_predictions = (targets.unsqueeze(-1) == top_k_predictions).any( + dim=-1 + ) # [b, t] + fraction_model_predicted_target = is_target_in_predictions.float().mean(dim=-1) # [b] + scores[f"token_accuracy_top_{top_k}"] = fraction_model_predicted_target.tolist() + + return scores + + +def _calculate_per_sample_kl_div_loss( + logits: torch.Tensor, + batch_target_probs: torch.Tensor | LowMemorySparseTensor, +) -> list[float]: + if isinstance(batch_target_probs, LowMemorySparseTensor): + logits = top_p_top_k(logits) + curr_target_probs = batch_target_probs.to_dense().to(logits.device) # .float() + per_sample_kl_div = [ + F.kl_div( + logits[i_sample].log_softmax(-1), + curr_target_probs[i_sample], + reduction="none", + log_target=False, + ) + .sum(-1) + .mean(-1) + .item() + for i_sample in range(logits.shape[0]) + ] + return per_sample_kl_div + + +def cosine_embedding_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return kd_model.cosine_embedding_loss_batched(hidden_states, target_hidden_states).tolist() + + +def normalized_mse_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + kd_model.normalized_mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def mse_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + F.mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def mae_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + F.l1_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def _calculate_teacher_similarity_scores( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, + logits: torch.Tensor, + target_logits: torch.Tensor, + calculate_full_score_ablations: bool, +) -> dict[str, list[float]]: + """ + hidden_states: [batch, tokens, n_embd] + target_hidden_states: [batch, tokens, n_embd] + logits: [batch, tokens, vocab] + target_logits: [batch, tokens, vocab] + """ + + def calc_per_sample(func, logits, target_probs): + return [ + func(logits=logits[i_sample], target_probs=target_probs[i_sample]) + for i_sample in range(logits.shape[0]) + ] + + score_ablations = {} + + if (target_hidden_states is not None) and (hidden_states.shape == target_hidden_states.shape): + for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss): + score_name = f"{func.__name__}_hidden_states" + score_ablations[score_name] = func(hidden_states, target_hidden_states) + + if target_logits is not None: + for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss): + score_name = f"{func.__name__}_logits" + score_ablations[score_name] = func(logits, target_logits) + + for top_p in (0.99, 0.95, None) if calculate_full_score_ablations else (None,): + transformed_logits = ( + logits if (top_p is None) else top_p_top_k(logits, top_p=top_p, top_k=None) + ) + transformed_target_logits = ( + target_logits + if (top_p is None) + else top_p_top_k(target_logits, top_p=top_p, top_k=None) + ) + target_probs = transformed_target_logits.softmax(-1) + + for func in (kl_div, js_div, tv_dist): + for clip_epsilon in ( + ( + ClipEpsilon.NO_CLIP, + ClipEpsilon.CLIP_NO_RENORMALIZE, + ClipEpsilon.CLIP_RENORMALIZE, + ) + if calculate_full_score_ablations + else (ClipEpsilon.NO_CLIP,) + ): + epsilon_factors = ( + (1.0, 0.1, 0.01) if not clip_epsilon == ClipEpsilon.NO_CLIP else (None,) + ) + + for epsilon_factor in epsilon_factors: + score_name = ( + f"{func.__name__}--top_p_{top_p}--clip_epsilon_{clip_epsilon.name}" + f"--epsilon_factor_{epsilon_factor}" + ) + func_with_args = functools.partial( + func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor + ) + score_ablations[score_name] = calc_per_sample( + func_with_args, transformed_logits, target_probs + ) + if (top_p is None) and (clip_epsilon == ClipEpsilon.NO_CLIP): + short_score_name = func.__name__ + score_ablations[short_score_name] = score_ablations[score_name] + + for top_k in (1, 5, 10): + teacher_greedy_prediction = target_logits.argmax(dim=-1, keepdim=True) # [b,t,1] + student_top_k_predictions = logits.topk(top_k, dim=-1).indices # [b,t,k] + is_teacher_prediction_in_student_predictions = ( + teacher_greedy_prediction == student_top_k_predictions + ).any(dim=-1) # [b,t] + fraction_student_predicted_teacher = ( + is_teacher_prediction_in_student_predictions.float().mean(dim=-1) + ) # [b] + score_ablations[f"greedy_teacher_prediction_in_student_top_{top_k}"] = ( + fraction_student_predicted_teacher.tolist() + ) + + if calculate_full_score_ablations: + for top_p in (0.99, 0.95, 0.50, None): + # student + transformed_logits = logits.clone() + + # teacher + transformed_target_logits = ( + target_logits.clone() + if (top_p is None) + else top_p_top_k(target_logits, top_p=top_p, top_k=None) + ) + + target_probs = transformed_target_logits.softmax(-1) + mask = transformed_target_logits == -1000 + if torch.any(mask): + transformed_logits[mask] = 0 + transformed_target_logits[mask] = 0 + target_probs[mask] = 0 + + for func in (mse_loss, mae_loss): + score_name = f"{func.__name__}_logits_top_p_{top_p}" + score_ablations[score_name] = func( + transformed_logits, transformed_target_logits + ) + + if top_p is not None and top_p > 0.9: + func = kl_div + clip_epsilon = ClipEpsilon.NO_CLIP + score_name = ( + f"{func.__name__}--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered" + ) + func_with_args = functools.partial( + func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor + ) + score_ablations[score_name] = calc_per_sample( + func_with_args, logits, target_probs + ) + # score_name = f"{func.__name__}_abs--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered" + # score_ablations[score_name] = [s.abs() for s in score_ablations[score_name]] + + return score_ablations + + +class ClipEpsilon(Enum): + NO_CLIP = "NO_CLIP" + CLIP_RENORMALIZE = "CLIP_RENORMALIZE" + CLIP_NO_RENORMALIZE = "CLIP_NO_RENORMALIZE" + + +def _logits_to_logprobs( + logits: torch.Tensor, clip_epsilon: ClipEpsilon, epsilon_factor: float +) -> torch.Tensor: + """ + logits: [tokens, vocab] + """ + logprobs = logits.log_softmax( + -1 + ) # must normalize logits before clipping otherwise log(1/voacb) means nothing + if clip_epsilon == ClipEpsilon.NO_CLIP: + return logprobs + vocab_size = logprobs.shape[-1] + epsilon = math.log(epsilon_factor * 1 / vocab_size) + logprobs = torch.clip(logprobs, min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + logprobs = logprobs.log_softmax( + -1 + ) # we do log_softmax again to retain legitimate distributions + return logprobs + + +def kl_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """ + Kullback-Leibler Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens = logits.shape[0] + logprobs = _logits_to_logprobs(logits, clip_epsilon, epsilon_factor) + + _kl_div = ( + F.kl_div(logprobs, target_probs, reduction="sum", log_target=False).item() / num_tokens + ) + return _kl_div + + +def js_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """ + Jensen-Shannon Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + probs = logits.softmax(-1) + mixture_probs = (probs + target_probs) / 2 + mixture_logprobs = mixture_probs.log().clip(min=-1000) + + pred_kl_div = kl_div(mixture_logprobs, probs, clip_epsilon, epsilon_factor) + target_kl_div = kl_div(mixture_logprobs, target_probs, clip_epsilon, epsilon_factor) + _js_div = 0.5 * (pred_kl_div + target_kl_div) + return _js_div + + +def tv_dist( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """ + Total Variation Distance (L1-loss) for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens, vocab_size = logits.shape + probs = logits.softmax(-1) + + if clip_epsilon != ClipEpsilon.NO_CLIP: + epsilon = epsilon_factor * 1 / vocab_size + probs = probs.clip(min=epsilon) + target_probs = target_probs.clip(min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + probs = probs / probs.sum(-1, keepdim=True) + target_probs = target_probs / target_probs.sum(-1, keepdim=True) + + _tv_dist = 0.5 * (probs - target_probs).abs().sum().item() / num_tokens + return _tv_dist + + +DEFAULT_TOP_P = 0.999 +# WestLake model: +# 700 = percentile 0.9 for top_p=0.99 +# 1700 = percentile 0.95 for top_p=0.99 and percentile 0.75 for top_p=0.999 +# For top_p=0.999 and top_k=1700 you take about 75 GB for 2048*8192 tokens +DEFAULT_TOP_K = 1000 + + +def calculate_sparse_probs( + logits: torch.Tensor, + top_p: float | None = DEFAULT_TOP_P, + top_k: int | None = DEFAULT_TOP_K, + verbose: bool = False, +) -> LowMemorySparseTensor: + warped_logits = top_p_top_k(logits, top_p, top_k) + probs = warped_logits.softmax(-1) + sparse_probs = LowMemorySparseTensor(probs) + if True: # Always calculate these metrics (was: if verbose or True:) + probs_unfiltered = logits.softmax(-1) + num_active_per_token = (warped_logits > -1000).sum(-1).float() + prob_density = torch.tensor( + [ + probs_unfiltered[i, j, warped_logits[i, j] > -1000].sum(-1).float() + for j in range(probs_unfiltered.shape[1]) + for i in range(probs_unfiltered.shape[0]) + ] + ) + + print(f""" + Sparsity: + {num_active_per_token.mean().item()=} + {num_active_per_token.quantile(0.25).item()=} + {num_active_per_token.quantile(0.5).item()=} + {num_active_per_token.quantile(0.75).item()=} + {num_active_per_token.quantile(0.9).item()=} + {num_active_per_token.quantile(0.95).item()=} + {num_active_per_token.max().item()=} + + {probs_unfiltered.shape=} + {prob_density.shape=} + {prob_density.mean().item()=} + {prob_density.quantile(0.25).item()=} + {prob_density.quantile(0.5).item()=} + {prob_density.quantile(0.75).item()=} + {prob_density.quantile(0.9).item()=} + {prob_density.quantile(0.95).item()=} + {prob_density.max().item()=} + """) + return sparse_probs + + +def top_p_top_k( + logits: torch.Tensor, + top_p: float | None = DEFAULT_TOP_P, + top_k: int | None = DEFAULT_TOP_K, + filter_value=-1000, +) -> torch.Tensor: + logit_warpers = [] + if top_p is not None: + logit_warpers.append(TopPLogitsWarper(top_p=top_p, filter_value=filter_value)) + if top_k is not None: + logit_warpers.append(TopKLogitsWarper(top_k=top_k, filter_value=filter_value)) + + warped_logits = [] + for sample_logits in logits: + for warper in logit_warpers: + sample_logits = warper(input_ids=None, scores=sample_logits) + warped_logits.append(sample_logits) + warped_logits = torch.stack(warped_logits) + + return warped_logits diff --git a/pyproject.toml b/pyproject.toml index 90fb8af48f..7694cbaf3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,9 +62,10 @@ extend-ignore = [ "__init__.py" = ["F401", "F403"] "examples/*" = ["D"] "tests/*" = ["B017", "D", "E402", "PT012"] -"*/_[a-zA-Z]*" = ["D"] # Private packages (_abc/*.py) or modules (_xyz.py) -"*.ipynb" = ["D", "E501"] # Ignore missing docstrings or line length for Jupyter notebooks -"modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style +"*/_[a-zA-Z]*" = ["D"] # Private packages (_abc/*.py) or modules (_xyz.py) +"*.ipynb" = ["D", "E501"] # Ignore missing docstrings or line length for Jupyter notebooks +"modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style +"modelopt/torch/_compress/*" = ["C4", "D", "E", "F", "FURB", "I", "ISC", "N", "PERF", "PGH", "PIE", "PLE", "PLR", "PT", "RUF", "SIM", "TC", "UP", "W"] # TODO:Disabled for now, will enable later, once all puzzletron code is migrated [tool.ruff.lint.pycodestyle] diff --git a/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml index 046ff51f65..178edb50d8 100644 --- a/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml +++ b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml @@ -12,4 +12,4 @@ write_results: false calc_losses_on_cpu: false activations_log_dir: model_name_or_path: -load_dataset_fn: ${get_object:utils.data.dataloaders.load_from_disk_fn} +load_dataset_fn: ${get_object:modelopt.torch._compress.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/gpu/torch/export/test_fsdp2_export.py b/tests/gpu/torch/export/test_fsdp2_export.py index 0c3496dec1..690b363fa3 100644 --- a/tests/gpu/torch/export/test_fsdp2_export.py +++ b/tests/gpu/torch/export/test_fsdp2_export.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import copy from functools import partial From 8c9cdd49b8a8c2c9ddb63d32e00d84ecfe4043da Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 26 Nov 2025 17:38:59 +0100 Subject: [PATCH 16/62] Add L2NormHook and use it in megatron.py (#599) ## What does this PR do? - Add L2NormHook and use it in megatron.py - Using L2NormHook removes code duplication between _DynamicSelfAttention and _DynamicMLP This is the first step towards reusing activation scores logic across Minitron and Puzzle. Next steps: - complete redesign of megatron.py - move other activation hooks logic to hooks.py - then combined those hooks.py with a similar hooks.py functoriality in puzzle (modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py) Questions: - why in the code before and after this redesign we store temp variables in two ways _register_temp_attribute and self.hook_handle)? ``` self._register_temp_attribute("_activation_hook", activation_hook) # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute? self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook) ``` --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Daniel Korzekwa Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/nas/plugins/megatron.py | 96 ++++++---------- modelopt/torch/nas/plugins/megatron_hooks.py | 104 ++++++++++++++++++ .../test_mcore_gpt_minitron_pruning.py | 48 ++++++++ 3 files changed, 184 insertions(+), 64 deletions(-) create mode 100644 modelopt/torch/nas/plugins/megatron_hooks.py diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index be34a84aa2..9796c5289e 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -55,7 +55,6 @@ from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_layer import TransformerLayer -from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.opt.dynamic import DynamicModule from modelopt.torch.opt.hparam import HPType from modelopt.torch.opt.searcher import ConstraintsDict @@ -77,11 +76,12 @@ ConstraintsRes, ) from ..hparams.concat import build_concat_hp -from ..modules import _DynamicLayerNorm +from ..modules import DynamicModuleList, _DynamicLayerNorm from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices from ..registry import DMRegistry from ..search_space import SampleFunc from ..traced_hp import TracedHp +from .megatron_hooks import MegatronL2NormHook SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"} @@ -265,39 +265,19 @@ def _setup(self): # can be discarded. # This limitation might be fixed in OMNIML-180 (Flexible Importance Estimator) # where we separate the importance estimation from the dynamic module. - self._register_temp_attribute("_activations", None) - self.hook_handle = self.linear_fc2.register_forward_hook(self._linear_fc2_forward_hook) + max_ffn_size = int(self.get_hparam(self.hparam_name).max) # type: ignore[arg-type] + activation_hook = MegatronL2NormHook(max_size=max_ffn_size) + self._register_temp_attribute("_activation_hook", activation_hook) + # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute? + self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook) ffn_hidden_size.register_importance(self._estimate_importance) - def _linear_fc2_forward_hook(self, module, input, output): - """Hook to collect activations for importance estimation. - - Activations are computed as mean over seq_len and then squared and summed over batch_size. - Later we take the square root of the sum to get the L2 norm. - """ - # Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions - # NOTE: This is not used at the moment since we restrict to TP=1 - input = gather_from_tensor_model_parallel_region(input[0]).detach() - if input.dim() == 2: - # For sparse experts, there is no batch dimension. - input = input[:, None, :] - # Dont aggregate activations from non-max subnets (e.g. from profiling) - if input.shape[-1] != self.get_hparam(self.hparam_name).max: - return - - input = input.to(torch.float32) # use full precision to avoid overflow - activations = input.abs().mean(dim=0) # [batch_size, ffn_hidden_size] - activations = activations.pow(2).sum(dim=0) # [ffn_hidden_size] - if self._activations is None: - self._activations = activations - else: - self._activations += activations - def _estimate_importance(self) -> TracedHp.Importance: """Return the activation magnitude-based importance of the ffn_hidden_size.""" - assert self._activations is not None, "No activations collected for importance estimation." - # Convert squared sum to L2 norm - return self._activations.pow(0.5) + assert self._activation_hook._activations is not None, ( + "No activations collected for importance estimation." + ) + return self._activation_hook.accumulate() def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: """Set hidden size for shared expert.""" @@ -612,46 +592,26 @@ def _setup(self): ) # register importance estimator for linear_qkv.output_size and linear_proj.input_size - self._register_temp_attribute("_activations", None) - self.hook_handle = self.linear_proj.register_forward_hook(self._linear_proj_forward_hook) + num_heads_per_group_max = int(self.get_hparam("num_heads_per_group").max) # type: ignore[arg-type] + num_query_groups_max = int(self.get_hparam("num_query_groups").max) # type: ignore[arg-type] + max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels + activation_hook = MegatronL2NormHook(max_size=max_size) + self._register_temp_attribute("_activation_hook", activation_hook) + # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute? + self.hook_handle = self.linear_proj.register_forward_hook(activation_hook) # NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads, # otherwise we would only have aggregated importance of heads per group. # While enforcing order during `sort_parameters`, we dont check the shape of the slice_order num_heads_per_group.register_importance(self._estimate_all_head_importance) num_query_groups.register_importance(self._estimate_query_group_importance) - def _linear_proj_forward_hook(self, module, input, output): - """Hook to collect activations for importance estimation. - - Activations are computed as mean over seq_len and then squared and summed over batch_size. - Later we take the square root of the sum to get the L2 norm. - """ - # Gather input [seq_len, batch_size, query_projection_size] over all TP regions - # NOTE: This is not used at the moment since we restrict to TP=1 - input = gather_from_tensor_model_parallel_region(input[0]).detach() - - # Dont aggregate activations from non-max subnets (e.g. from profiling) - if ( - input.shape[-1] - != self.get_hparam("num_heads_per_group").max - * self.get_hparam("num_query_groups").max - * self.config.kv_channels - ): - return - - input = input.to(torch.float32) # use full precision to avoid overflow - activations = input.abs().mean(dim=0) - activations = activations.pow(2).sum(dim=0) # [query_projection_size] - if self._activations is None: - self._activations = activations - else: - self._activations += activations - def _estimate_all_head_importance(self) -> TracedHp.Importance: """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups).""" - assert self._activations is not None, "No activations collected for importance estimation." + assert self._activation_hook._activations is not None, ( + "No activations collected for importance estimation." + ) # Convert squared sum to L2 norm - scores = self._activations.pow(0.5) + scores = self._activation_hook.accumulate() attn_head_importance = torch.linalg.vector_norm( scores.view( self.get_hparam("num_heads_per_group").max @@ -665,9 +625,11 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance: def _estimate_query_group_importance(self) -> TracedHp.Importance: """Return the importance of the ``num_query_groups`` hparam.""" - assert self._activations is not None, "No activations collected for importance estimation." + assert self._activation_hook._activations is not None, ( + "No activations collected for importance estimation." + ) # Convert squared sum to L2 norm - scores = self._activations.pow(0.5) + scores = self._activation_hook.accumulate() group_importance = torch.linalg.vector_norm( scores.view( self.get_hparam("num_heads_per_group").max, @@ -1594,8 +1556,11 @@ def get_activations_and_layer_scores( """Get the per-rank activations and layer scores from the module.""" local_activations = {} for n, m in self.named_modules(): + # TODO: Remove legacy _activations check once all modules use _activation_hook if hasattr(m, "_activations"): local_activations[n] = m._activations + elif hasattr(m, "_activation_hook"): + local_activations[n] = m._activation_hook._activations activations_per_rank = dist.allgather( local_activations, group=get_pipeline_model_parallel_group() ) @@ -1624,8 +1589,11 @@ def set_activations_and_layer_scores( for layer in self.decoder.layers: layer._scores = layer_scores[layer.layer_number] for n, m in self.named_modules(): + # TODO: Remove legacy _activations check once all modules use _activation_hook if hasattr(m, "_activations"): m._activations = activations_per_rank[rank][n] + elif hasattr(m, "_activation_hook"): + m._activation_hook._activations = activations_per_rank[rank][n] def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None: diff --git a/modelopt/torch/nas/plugins/megatron_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks.py new file mode 100644 index 0000000000..833e03c042 --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Forward hooks for activation-based importance estimation (megatron NAS plugin).""" + +from abc import ABC, abstractmethod + +import torch +from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region +from torch import nn + + +class ForwardHook(ABC): + """Base class for PyTorch forward hooks. + + This follows the PyTorch forward hook API where the second + parameter is 'args' (a tuple of positional arguments passed to forward()). + + Usage: + hook = MyHook() + module.register_forward_hook(hook) + """ + + @abstractmethod + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that is called after the module's forward pass. + + Args: + module: The module this hook is registered on + args: Tuple of positional arguments passed to module.forward() + output: The output from module.forward() + + Returns: + None (does not modify the output) + """ + ... + + +class MegatronL2NormHook(ForwardHook): + """Hook for accumulating activation statistics for importance estimation. + + Activations are computed as mean over seq_len and then squared and summed over batch_size. + In the accumulate() method we take the square root of the sum to get the L2 norm. + + Args: + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__(self, max_size: int | None = None): + """Initialize the L2NormHook.""" + self.max_size = max_size + self._activations: torch.Tensor | None = None + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Accumulate activation statistics from the forward pass.""" + # Gather input [seq_len, batch_size, hidden_size] over all TP regions + # NOTE: This is not used at the moment since we restrict to TP=1 + input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach() + + if input_tensor.dim() == 2: + # For sparse experts, there is no batch dimension. + input_tensor = input_tensor[:, None, :] + + # Dont aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and input_tensor.shape[-1] != self.max_size: + return + + input_tensor = input_tensor.to(torch.float32) # use full precision to avoid overflow + activations = input_tensor.abs().mean(dim=0) # [batch_size, hidden_size] + activations = activations.pow(2).sum(dim=0) # [hidden_size] + + if self._activations is None: + self._activations = activations + else: + self._activations += activations + + def accumulate(self) -> torch.Tensor: + """Return the accumulated L2 norm of activations. + + Returns: + Tensor of accumulated scores, one per channel + + Raises: + AssertionError: If no activations have been collected yet + """ + assert self._activations is not None, "No activations collected for importance estimation." + # Convert squared sum to L2 norm + return self._activations.pow(0.5) diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 2aa67b4ec1..094fc015d0 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -87,10 +87,12 @@ def _get_model(initialize_megatron=True): normalization=normalization, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, + use_cpu_initialization=True, # Ensure deterministic weight init across CUDA versions ).cuda() return model model = _get_model() + sd = model.state_dict() def forward_loop(m): @@ -134,6 +136,52 @@ def forward_loop(m): assert pruning_scores["layer_scores"] assert pruning_scores["activations_per_rank"] + # TODO: Simplify it: this unit test is too long, + # hard to read (the same set of assertions across different test cases with if-else). + + assert len(pruning_scores["activations_per_rank"]) == 1 + rank_0_activations = pruning_scores["activations_per_rank"][0] + + # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4) + if pruned_ffn_div == 4: + # Layer scores + assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-3) + assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-3) + + # Validate decoder.layers.0.mlp activations + mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"] + assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-3) + assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-3) + assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-3) + + # Validate decoder.layers.1.mlp activations + mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"] + assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-3) + assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-3) + assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-3) + + # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2) + elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1: + # Layer scores + assert pruning_scores["layer_scores"][1] == pytest.approx(2.1415508985519409, abs=1e-3) + assert pruning_scores["layer_scores"][2] == pytest.approx(1.7198008894920349, abs=1e-3) + + # Validate decoder.layers.0.self_attention activations + assert "decoder.layers.0.self_attention" in rank_0_activations + attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"] + assert attn_0_acts.shape == torch.Size([256]) + assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-3) + assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-3) + assert attn_0_acts.mean().item() == pytest.approx(0.1613342612981796, abs=1e-3) + + # Validate decoder.layers.1.self_attention activations + assert "decoder.layers.1.self_attention" in rank_0_activations + attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"] + assert attn_1_acts.shape == torch.Size([256]) + assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-3) + assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-3) + assert attn_1_acts.mean().item() == pytest.approx(0.4782669544219971, abs=1e-3) + # Assert weights are pruned correctly for layer in model.decoder.layers: assert layer.mlp.linear_fc1.weight.shape == ( From 1f724665fd5a792f2dadb845ab17bafcd1d8be29 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 27 Nov 2025 09:03:13 +0100 Subject: [PATCH 17/62] Add pruning checkpoints for the compress algorithm (#607) ## What does this PR do? Add pruning checkpoints for the compress algorithm. --------- Signed-off-by: Daniel Korzekwa --- .../score_pruning_activations.py | 2 +- modelopt/torch/_compress/compress.py | 4 +- .../nas/plugins/compress_nas_plugin.py | 2 +- .../torch/_compress/pruning/pruning_ckpts.py | 351 ++++ .../torch/_compress/sewing_kit/__init__.py | 17 +- modelopt/torch/_compress/sewing_kit/core.py | 14 +- .../_compress/sewing_kit/passage/__init__.py | 11 +- .../_compress/sewing_kit/passage/core.py | 13 +- modelopt/torch/_compress/sewing_kit/utils.py | 14 +- .../tools/bypassed_training/child_init.py | 1641 +++++++++++++++++ .../init_child_from_parent.py | 266 +++ modelopt/torch/_compress/tools/kd_model.py | 4 +- .../tools/sharded_checkpoint_utils.py | 1 - .../torch/_compress/tools/validate_model.py | 14 +- .../_compress/utils/checkpoint_manager.py | 5 +- .../torch/_compress/utils/data/dataloaders.py | 3 +- .../torch/_compress/utils/data/dataset.py | 3 +- .../utils/validate_runtime_pipeline.py | 11 +- modelopt/torch/_compress/utils/validation.py | 3 +- pyproject.toml | 2 +- setup.py | 1 + 21 files changed, 2323 insertions(+), 59 deletions(-) create mode 100644 modelopt/torch/_compress/pruning/pruning_ckpts.py create mode 100644 modelopt/torch/_compress/tools/bypassed_training/child_init.py create mode 100644 modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py index ef1e6c2738..4a276e8e82 100644 --- a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -18,13 +18,13 @@ import hydra import torch from omegaconf import DictConfig -from modelopt.torch._compress.utils.parsing import format_global_config from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.validate_model import validate_model from modelopt.torch._compress.utils.dist_utils import is_distributed +from modelopt.torch._compress.utils.parsing import format_global_config def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 64e241d104..765e3d6d42 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -23,12 +23,12 @@ import build_library_and_stats import mip_and_realize_models import pruning_ckpts -import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations import scoring from omegaconf import DictConfig -from modelopt.torch._compress.tools.runtime import IRuntime +import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.runtime import IRuntime def compress( diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 84af06b137..8fbf7c7c47 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -25,11 +25,11 @@ import build_library_and_stats import mip_and_realize_models -import pruning_ckpts import scoring import torch from torch import nn +import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, diff --git a/modelopt/torch/_compress/pruning/pruning_ckpts.py b/modelopt/torch/_compress/pruning/pruning_ckpts.py new file mode 100644 index 0000000000..4a0e5c15cd --- /dev/null +++ b/modelopt/torch/_compress/pruning/pruning_ckpts.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for creating pruned model checkpoints. + +This module provides functions to generate pruned checkpoints by modifying model architectures +(FFN intermediate sizes, attention head groups, hidden dimensions) and initializing child pruned models +from parent checkpoints. +""" + +# mypy: ignore-errors +import json +import os +import time +from typing import Optional + +import hydra +from omegaconf import DictConfig + +from modelopt.torch._compress.tools.bypassed_training.child_init import ( + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, +) +from modelopt.torch._compress.tools.bypassed_training.init_child_from_parent import ( + init_child_from_parent, +) +from modelopt.torch._compress.tools.checkpoint_utils import load_model_config +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.validate_model import validate_model + + +def launch_ffn_intermediates_prune_ckpt( + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None +): + for intermediate_size in cfg.pruning.intermediate_size_list: + dirname = f"ffn_{intermediate_size}_attn_no_op" + + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)): + mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved") + continue + + mprint("Process intermediate_size {}".format(intermediate_size)) + + model_config_overrides_json = {"ffn": [{"intermediate_size": intermediate_size}]} + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_json=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, dirname)) + + mprint(f"=== COMPLETED FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_attn_groups_prune_ckpt( + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None +): + for n_heads_in_group in cfg.pruning.n_heads_in_group_list: + dirname = f"n_heads_in_group{n_heads_in_group}" + + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)): + mprint(f"Process n_heads_in_group {n_heads_in_group} has already been pruned & saved") + continue + + mprint("Process n_heads_in_group {}".format(n_heads_in_group)) + mprint(f"=== STARTING ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===") + + model_config_overrides_json = {"attention": [{"n_heads_in_group": n_heads_in_group}]} + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_json=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, dirname)) + + mprint(f"=== COMPLETED ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_hidden_dim_prune_ckpt(cfg: DictConfig): + """Launch hidden dimension pruning using channel importance ranking.""" + # Get channel importance results from the activations log directory + activations_log_dir = cfg.pruning.activations_log_dir + channel_importance_path = os.path.join(activations_log_dir, "channel_importance_results.json") + + if not os.path.exists(channel_importance_path): + raise FileNotFoundError( + f"Channel importance results not found at {channel_importance_path}. " + f"Make sure to run the activation collection step first." + ) + + # Load parent model config to get FFN configuration + parent_model_config = load_model_config(cfg.pruning.model_name_or_path) + parent_hidden_size = parent_model_config.hidden_size + + # Get teacher's FFN configuration + intermediate_sizes = [] + for block_config in parent_model_config.block_configs: + if block_config.ffn.intermediate_size is not None: + intermediate_sizes.append(block_config.ffn.intermediate_size) + else: + intermediate_sizes.append(None) + + mprint(f"Teacher config:") + mprint(f" - hidden_size: {parent_hidden_size}") + mprint(f" - intermediate_sizes: {intermediate_sizes}") + os.makedirs(os.path.join(cfg.puzzle_dir, "ckpts"), exist_ok=True) + + for hidden_size in cfg.pruning.hidden_size_list: + mprint(f"\n######################################################################") + mprint(f"Hidden Size = {hidden_size}") + mprint(f"######################################################################\n") + + mprint(f"Child config:") + mprint(f" - hidden_size: {hidden_size}") + + # Create model config overrides with proper FFN configuration + model_config_overrides_json = json.dumps( + { + "hidden_size": hidden_size, + "ffn": [ + { + "intermediate_size": intermediate_size, + } + for intermediate_size in intermediate_sizes + ], + } + ) + + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + dirname = f"hidden_size_{hidden_size}" + output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + + mprint(f"Creating checkpoint with hidden_size={hidden_size}") + mprint(f"Model config overrides: {model_config_overrides_json}") + + init_child_from_parent( + parent_checkpoint_dir=cfg.pruning.model_name_or_path, + model_config_overrides_json=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode(cfg.pruning.linear_init_mode), + hidden_size_init_mode=HiddenSizeInitMode(cfg.pruning.hidden_size_init_mode), + channel_importance_path=channel_importance_path, + ) + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, dirname)) + mprint(f"Created pruned checkpoint at: {output_dir}") + + +def launch_experts_prune_ckpt( + cfg: DictConfig, + max_save_workers: Optional[int] = None, + max_layer_workers: Optional[int] = None, + symlink_suffix: Optional[str] = None, +): + for num_experts in cfg.pruning.num_experts_to_keep_list: + dirname = f"num_experts_{num_experts}" + # Create symlink name with optional suffix + symlink_name = f"{dirname}_{symlink_suffix}" if symlink_suffix else dirname + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", symlink_name)): + mprint( + f"Process num_experts {num_experts} (symlink: {symlink_name}) has already been pruned & saved" + ) + continue + mprint(f"Process num_experts {num_experts}") + mprint(f"=== STARTING EXPERT PRUNING FOR num_experts={num_experts} ===") + model_config_overrides_json = {"ffn": [{"moe": {"num_local_experts": num_experts}}]} + + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_json=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, symlink_name)) + + mprint(f"=== COMPLETED EXPERT PRUNING FOR num_experts={num_experts} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_moe_ffn_intermediates_prune_ckpt( + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None +): + for intermediate_size in cfg.pruning.intermediate_size_list: + dirname = f"moe_ffn_{intermediate_size}_attn_no_op" + + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)): + mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved") + continue + + mprint("Process intermediate_size {}".format(intermediate_size)) + + model_config_overrides_json = { + "attention": [{"no_op": True, "llama4": None}], + "ffn": [{"moe": {"expert_intermediate_dim": intermediate_size}}], + } + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_json=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + os.symlink(output_dir, os.path.join(cfg.puzzle_dir, "ckpts", dirname)) + + mprint(f"=== COMPLETED MOE FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_prune_ckpt(cfg: DictConfig): + target_layer = cfg.pruning.activation_hooks_kwargs.target_layer + # I/O optimization settings - same as FFN pruning + max_save_workers = None # Will auto-calculate as min(CPU count, num files) + if "PRUNING_SAVE_WORKERS" in os.environ: + max_save_workers = int(os.environ["PRUNING_SAVE_WORKERS"]) + + # Layer workers now auto-calculate but can still be overridden + max_layer_workers = None # Will auto-calculate as min(CPU count, num layers) + if "PRUNING_LAYER_WORKERS" in os.environ: + max_layer_workers = int(os.environ["PRUNING_LAYER_WORKERS"]) + + # Log optimization settings (extracted from individual pruning methods) + mprint(f"Optimization Settings:") + mprint( + f" - I/O workers (max_workers): {'auto-calculate' if max_save_workers is None else max_save_workers}" + ) + mprint( + f" - Layer workers (max_layer_workers): {'auto-calculate' if max_layer_workers is None else max_layer_workers}" + ) + mprint(f" (Override with env vars: PRUNING_IO_WORKERS, PRUNING_LAYER_WORKERS)") + + if target_layer == "mlp.down_proj": + launch_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) + elif target_layer == "self_attn.o_proj": + launch_attn_groups_prune_ckpt(cfg, max_save_workers, max_layer_workers) + elif target_layer == "layernorm": + launch_hidden_dim_prune_ckpt(cfg) + elif target_layer == "router": + # Check if we should use symlink suffix for chained pruning + symlink_suffix = getattr(cfg.pruning, "symlink_suffix", None) + launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers, symlink_suffix) + elif target_layer == "regex:experts\.\d+\.down_proj$": + launch_moe_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) + else: + raise NotImplementedError( + f"checkpoint pruning is not currently supported for target layer: {target_layer}" + ) + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + mprint(cfg) + launch_prune_ckpt(cfg) + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/sewing_kit/__init__.py b/modelopt/torch/_compress/sewing_kit/__init__.py index 6df9f8afa8..c8f7ffa013 100644 --- a/modelopt/torch/_compress/sewing_kit/__init__.py +++ b/modelopt/torch/_compress/sewing_kit/__init__.py @@ -13,22 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # mypy: ignore-errors + from .core import ( - Needle, + CantResolveNodeDependenciesException, + ConstantTarget, + ExternalTarget, + FunctionTarget, + InputsLoopFoundException, KnotException, LoopFoundException, - InputsLoopFoundException, + ModuleTarget, MultipleExternalNodesException, + Needle, OnlyInternalNodesException, OutputsLoopFoundException, - ExternalTarget, - ModuleTarget, - ConstantTarget, - FunctionTarget, RemoteTarget, StitchedModule, StitchedModuleException, - CantResolveNodeDependenciesException, StitchedModuleOutput, ) -from .passage import always_false_predicate, always_true_predicate, InputArgs +from .passage import InputArgs, always_false_predicate, always_true_predicate diff --git a/modelopt/torch/_compress/sewing_kit/core.py b/modelopt/torch/_compress/sewing_kit/core.py index 550c1298ca..8f926954b5 100644 --- a/modelopt/torch/_compress/sewing_kit/core.py +++ b/modelopt/torch/_compress/sewing_kit/core.py @@ -14,11 +14,14 @@ # limitations under the License. # mypy: ignore-errors + from __future__ import annotations + from abc import ABC from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union + from typing_extensions import override try: @@ -30,19 +33,18 @@ import torch.distributed import torch.nn as nn -from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip from .passage import ( - Passage, InputArgs, OutputValue, - Predicate, - always_false_predicate, + Passage, PassageInputAdapter, - PassageOutputAdapter, PassageInputOverrides, + PassageOutputAdapter, PassageOutputOverrides, + Predicate, + always_false_predicate, ) - +from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip InputAdapter = Callable[[InputArgs], InputArgs] OutputAdapter = Callable[..., OutputValue] diff --git a/modelopt/torch/_compress/sewing_kit/passage/__init__.py b/modelopt/torch/_compress/sewing_kit/passage/__init__.py index 98dfc683b3..02f4c19eac 100644 --- a/modelopt/torch/_compress/sewing_kit/passage/__init__.py +++ b/modelopt/torch/_compress/sewing_kit/passage/__init__.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. + from .core import ( - Passage, - PassageOutput, InputArgs, OutputValue, - Predicate, + Passage, PassageInputAdapter, - PassageOutputAdapter, PassageInputOverrides, + PassageOutput, + PassageOutputAdapter, PassageOutputOverrides, - always_true_predicate, + Predicate, always_false_predicate, + always_true_predicate, ) diff --git a/modelopt/torch/_compress/sewing_kit/passage/core.py b/modelopt/torch/_compress/sewing_kit/passage/core.py index 4a66638aac..71164f061f 100644 --- a/modelopt/torch/_compress/sewing_kit/passage/core.py +++ b/modelopt/torch/_compress/sewing_kit/passage/core.py @@ -15,10 +15,9 @@ # mypy: ignore-errors from __future__ import annotations -import sys - -from collections.abc import Sequence, Callable +import sys +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Any, ContextManager, Iterable, Mapping, Optional, Union @@ -27,19 +26,19 @@ except ImportError: from typing_extensions import Self +import torch.nn as nn from typing_extensions import override -import torch.nn as nn +from ..common import logger from ..utils import ( ActivityContext, - has_fake_tensor, + dynamo_skip, fake_tensors, + has_fake_tensor, is_submodule_of, is_submodule_or_same, real_tensors, - dynamo_skip, ) -from ..common import logger @dataclass diff --git a/modelopt/torch/_compress/sewing_kit/utils.py b/modelopt/torch/_compress/sewing_kit/utils.py index ebe90b2a44..16fe1b3fd3 100644 --- a/modelopt/torch/_compress/sewing_kit/utils.py +++ b/modelopt/torch/_compress/sewing_kit/utils.py @@ -16,7 +16,7 @@ from __future__ import annotations import inspect -from collections.abc import Sequence, Mapping +from collections.abc import Mapping, Sequence from contextlib import contextmanager from typing import ( Any, @@ -31,17 +31,17 @@ cast, overload, ) -from typing_extensions import override + import torch -import torch.distributed -import torch._dynamo import torch._C -from torch import Tensor -import torch.utils._pytree as pytree +import torch._dynamo +import torch.distributed import torch.nn as nn import torch.nn.functional as F +import torch.utils._pytree as pytree +from torch import Tensor from torch._subclasses import FakeTensor, FakeTensorMode - +from typing_extensions import override Fn = TypeVar("Fn", bound=Callable) diff --git a/modelopt/torch/_compress/tools/bypassed_training/child_init.py b/modelopt/torch/_compress/tools/bypassed_training/child_init.py new file mode 100644 index 0000000000..d9ead79a1c --- /dev/null +++ b/modelopt/torch/_compress/tools/bypassed_training/child_init.py @@ -0,0 +1,1641 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""TODO Add description. Analyze this code, why is it so long and complex? Can it be simplified?""" + +import concurrent.futures +import dataclasses +import json +import os +import re +import time +from copy import deepcopy +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from typeguard import check_type + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + SUBBLOCK_CLS_DICT, + BlockConfig, + _get_dataclass_type, + _is_dataclass_type, +) +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.tools.logger import aprint, mprint +from modelopt.torch._compress.tools.runtime import IRuntime + + +class GQAInitMode(Enum): + RandomKV = "RandomKV" + AverageKV = "AverageKV" + FirstKV = "FirstKV" + RandomBlock = "RandomBlock" + CopyAsIs = "CopyAsIs" + Degrouping = "Degrouping" + PruneKVHeads = "PruneKVHeads" + + +class MlpInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + CopyAsIs = "CopyAsIs" + PruneByActivationsLog = "PruneByActivationsLog" + ExpertRemoval = "ExpertRemoval" + ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + MoEChannelPruning = "MoEChannelPruning" + + +class LinearInitMode(Enum): + Random = "Random" + FromTeacher = "FromTeacher" + + +class HiddenSizeInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + PruneByChannelRanking = "PruneByChannelRanking" + CopyAsIs = "CopyAsIs" + + +IgnoreFn = Callable[[str], bool] + +default_ignore_fn: IgnoreFn = lambda _: False + + +class Printer: + @staticmethod + def print(s: str) -> None: + print(s) + + +def _process_single_layer( + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: DeciLMConfig, + new_config: DeciLMConfig, + gqa_init_mode: GQAInitMode, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + linear_init_mode: LinearInitMode, + ignored_keys: set, + keys: dict, + is_original_mha: bool, + head_size: int, + hidden_size: int, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]: + """ + Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). + Thread-safe function for parallel layer processing. + """ + layer_out_state_dict = {} + keys_to_remove = {} + + parent_block_config = original_config.block_configs[layer_idx] + child_block_config = new_config.block_configs[layer_idx] + + # Attention processing + for part in ["weight", "bias"]: + attn_prefix = f"model.layers.{layer_idx}.self_attn" + q_key = f"{attn_prefix}.q_proj.{part}" + k_key = f"{attn_prefix}.k_proj.{part}" + v_key = f"{attn_prefix}.v_proj.{part}" + o_key = f"{attn_prefix}.o_proj.{part}" + attn_keys = [q_key, k_key, v_key, o_key] + # Drop attn keys that don't exist and required to be in the new state_dict + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] + if len(attn_keys) > 0 and all(key in keys for key in attn_keys): + for key in attn_keys: + keys_to_remove[key] = keys[key] + if all(key not in ignored_keys for key in attn_keys): + is_student_and_teacher_have_same_attention_implementation = all( + key in new_state_dict.keys() for key in attn_keys + ) + if is_student_and_teacher_have_same_attention_implementation: + if part == "weight": + wq, wk, wv, wo = _init_attention_weights( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk + layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo + else: + bias_sd = _init_attention_biases( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]): + if bias_key in bias_sd.keys(): + layer_out_state_dict[sd_key] = bias_sd[bias_key] + + else: + linear_attn_key = f"{attn_prefix}.linear_attn.weight" + is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict.keys() + if is_student_attn_replaced_with_linear: + if linear_init_mode == LinearInitMode.Random: + layer_out_state_dict[linear_attn_key] = new_state_dict[linear_attn_key] + elif linear_init_mode == LinearInitMode.FromTeacher: + layer_out_state_dict[linear_attn_key] = _init_linear_attn( + parent_state_dict, original_config, layer_idx, v_key, o_key + ) + else: + raise ValueError(f"Unknown {linear_init_mode=}") + else: + # student attn random init + for new_key in new_state_dict.keys(): + if attn_prefix in new_key: + layer_out_state_dict[new_key] = new_state_dict[new_key] + + # MLP/MoE processing + is_parent_moe = parent_block_config.ffn.is_moe + if not is_parent_moe: # not MoE, init the MLP + mlp_prefix = f"model.layers.{layer_idx}.mlp" + linear_mlp_key = f"{mlp_prefix}.linear_mlp.weight" + + is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict.keys() + if is_student_mlp_replaced_with_linear: + if linear_init_mode == LinearInitMode.Random: + layer_out_state_dict[linear_mlp_key] = new_state_dict[linear_mlp_key] + elif linear_init_mode == LinearInitMode.FromTeacher: + teacher_mlp_state_dict = { + k.split(mlp_prefix + ".")[1]: v + for k, v in parent_state_dict.items() + if mlp_prefix in k + } + layer_out_state_dict[linear_mlp_key] = _init_linear_mlp(teacher_mlp_state_dict) + else: + raise ValueError(f"Unknown {linear_init_mode=}") + else: + layer_out_state_dict.update( + _init_mlp( + mlp_init_mode=mlp_init_mode, + layer_idx=layer_idx, + original_config=original_config, + mlp_init_config=mlp_init_config, + original_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + new_config=new_config, + keys=keys, + ignored_keys=ignored_keys, + ) + ) + else: + is_child_moe = child_block_config.ffn.is_moe + if is_child_moe: + parent_moe_config = original_config.block_configs[layer_idx].ffn.moe + child_moe_config = new_config.block_configs[layer_idx].ffn.moe + if parent_moe_config == child_moe_config: + pass # copy the MoE as is + elif mlp_init_mode == MlpInitMode.MoEChannelPruning: + for expert_idx in range(parent_moe_config.num_local_experts): + layer_out_state_dict.update( + _init_mlp( + mlp_init_mode=mlp_init_mode, + layer_idx=layer_idx, + original_config=original_config, + mlp_init_config=mlp_init_config, + original_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + new_config=new_config, + keys=keys, + ignored_keys=ignored_keys, + expert_idx=expert_idx, + ) + ) + + elif mlp_init_mode == MlpInitMode.ExpertRemoval: # remove some of the routed experts + router_key, new_experts_keys = _generate_moe_keys( + layer_idx, child_block_config.ffn.moe.num_local_experts + ) + _, orig_experts_keys = _generate_moe_keys( + layer_idx, parent_block_config.ffn.moe.num_local_experts + ) + keys_to_remove[router_key] = keys.get(router_key) + for key in sum(orig_experts_keys.values(), []): + keys_to_remove[key] = keys.get(key) + + orig_experts_weights = { + name: [parent_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + new_experts_weights = { + name: [new_state_dict[key] for key in new_experts_module_keys] + for name, new_experts_module_keys in new_experts_keys.items() + } + out_router_weights, out_experts_weights = _init_moe_module( + layer_idx=layer_idx, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + orig_router_weight=parent_state_dict[router_key], + orig_experts_weights=orig_experts_weights, + new_router_weight=new_state_dict[router_key], + new_experts_weights=new_experts_weights, + ) + layer_out_state_dict[router_key] = out_router_weights + for name in new_experts_keys.keys(): + layer_out_state_dict.update( + zip(new_experts_keys[name], out_experts_weights[name]) + ) + elif child_block_config.ffn.no_op: # no-op, drop this layer + parent_mlp_prefix = f"model.layers.{layer_idx}.mlp" + for key in list(keys.keys()): + if key.startswith(parent_mlp_prefix): + keys_to_remove[key] = keys[key] + else: + assert mlp_init_mode == MlpInitMode.ConcatExpertsIntoDenseFFN, ( + "The parent layer is MoE and the child layer is a normal FFN. The only supported mode is ConcatExpertsAsMLP." + ) + + child_ffn_state_dict = _concatenate_experts_into_dense_ffn( + parent_state_dict, + mlp_init_config, + hidden_size, + layer_idx, + child_block_config, + parent_block_config, + ) + layer_out_state_dict.update(child_ffn_state_dict) + + for key in list(keys.keys()): + if key.startswith(f"model.layers.{layer_idx}.mlp"): + keys_to_remove[key] = keys[key] + + # Handle missing keys + for key_possibly_missing_in_student in [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.up_proj", + "mlp.down_proj", + "input_layernorm", + "post_attention_layernorm", + ]: + key_possibly_missing_in_student = f".{layer_idx}.{key_possibly_missing_in_student}" + is_key_missing_from_student = ( + len([k for k in new_state_dict.keys() if key_possibly_missing_in_student in k]) == 0 + ) + if is_key_missing_from_student: + for k in list(keys.keys()): + if key_possibly_missing_in_student in k: + keys_to_remove[k] = keys[k] + + return layer_out_state_dict, keys_to_remove + + +@torch.no_grad() +def create_child_state_dict( + original_state_dict: dict, + new_state_dict: dict, + original_config: DeciLMConfig, + new_config: DeciLMConfig, + gqa_init_mode: GQAInitMode, + ignore_fn: IgnoreFn = default_ignore_fn, + runtime: Optional[IRuntime] = Printer, + mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, + mlp_init_config: Optional[dict[str, Any]] = None, + owned_block_indexes: Optional[set[int]] = None, + linear_init_mode: LinearInitMode = LinearInitMode.Random, + hidden_size_init_mode: HiddenSizeInitMode = HiddenSizeInitMode.CopyAsIs, + channel_importance_path: Optional[str] = None, + max_layer_workers: Optional[int] = None, # Now optional - will auto-calculate if None +): + mprint("=== Starting create_child_state_dict with optimizations ===") + total_start_time = time.time() + + # Phase 1: Initial setup and validation + setup_start_time = time.time() + if owned_block_indexes is None: + owned_block_indexes = set(range(new_config.num_hidden_layers)) + + # Auto-calculate optimal layer workers: min(cpu_count, num_layers) + if max_layer_workers is None: + cpu_count = os.cpu_count() or 1 + num_layers = len(owned_block_indexes) + max_layer_workers = min(cpu_count, num_layers) + mprint( + f"Auto-calculated layer workers: min({cpu_count} CPUs, {num_layers} layers) = {max_layer_workers}" + ) + else: + mprint(f"Using specified layer workers: {max_layer_workers}") + + # Memory optimization: Pre-allocate output state dict with known shapes + expected_keys_and_shapes = {k: v.shape for k, v in new_state_dict.items()} + out_state_dict = {} + + # Pre-allocate tensors where possible to reduce memory fragmentation + for key, shape in expected_keys_and_shapes.items(): + if key in new_state_dict: + tensor = new_state_dict[key] + # Only make contiguous if necessary (memory optimization) + if not tensor.is_contiguous(): + out_state_dict[key] = tensor.contiguous() + else: + out_state_dict[key] = tensor + + original_n_heads_in_group_per_layer = [ + b.attention.n_heads_in_group for b in original_config.block_configs + ] + is_original_mha = set(original_n_heads_in_group_per_layer) == {1} + is_same_hidden_size = original_config.hidden_size == new_config.hidden_size + head_size = new_config.head_dim + orig_head_size = original_config.head_dim + assert head_size == orig_head_size, f"head_size {head_size} != orig_head_size {orig_head_size}" + + # Allow different hidden sizes for pruning + if not is_same_hidden_size: + assert new_config.hidden_size <= original_config.hidden_size, ( + f"New hidden size ({new_config.hidden_size}) must be <= original ({original_config.hidden_size})" + ) + assert hidden_size_init_mode != HiddenSizeInitMode.CopyAsIs, ( + "Cannot copy as is when hidden sizes differ" + ) + + hidden_size = original_config.hidden_size + + ignored_keys = set([key for key in original_state_dict.keys() if ignore_fn(key)]) + for key in ignored_keys: + aprint(f"Ignoring key {key} and taking its init from new_state_dict") + out_state_dict[key] = new_state_dict[key] + + keys = { + match.group(1) if (match := re.search(r"(h\.\d+\..*)", key)) is not None else key: key + for key in original_state_dict.keys() + } + setup_time = time.time() - setup_start_time + mprint(f"Phase 1 - Setup and memory pre-allocation: {setup_time:.2f}s") + + # Phase 2: Parallel layer processing + layer_processing_start_time = time.time() + + # Prepare arguments for parallel processing + process_layer_partial = partial( + _process_single_layer, + parent_state_dict=original_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + ) + + # Process layers in parallel with optimal worker count + mprint( + f"Processing {len(owned_block_indexes)} layers in parallel with {max_layer_workers} workers..." + ) + layer_results = [] + all_keys_to_remove = {} + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_layer_workers) as executor: + future_to_layer = { + executor.submit(process_layer_partial, layer_idx): layer_idx + for layer_idx in owned_block_indexes + } + + completed = 0 + for future in concurrent.futures.as_completed(future_to_layer): + layer_idx = future_to_layer[future] + try: + layer_state_dict, keys_to_remove = future.result() + layer_results.append((layer_idx, layer_state_dict)) + all_keys_to_remove.update(keys_to_remove) + + completed += 1 + if completed % 20 == 0 or completed == len( + owned_block_indexes + ): # More frequent progress updates + mprint(f"Completed {completed}/{len(owned_block_indexes)} layers") + except Exception as exc: + mprint(f"Layer {layer_idx} generated an exception: {exc}") + raise exc + + # Merge layer results into main state dict (memory efficient) + for layer_idx, layer_state_dict in layer_results: + out_state_dict.update(layer_state_dict) + + # Remove processed keys from the keys dict + for key_to_remove in all_keys_to_remove: + keys.pop(key_to_remove, None) + + layer_processing_time = time.time() - layer_processing_start_time + mprint( + f"Phase 2 - Parallel layer processing: {layer_processing_time:.2f}s ({max_layer_workers} workers)" + ) + + # Phase 3: Copy remaining keys from original model + copy_start_time = time.time() + keys_to_copy_from_orig_model = set(keys.values()) - ignored_keys + for key in keys_to_copy_from_orig_model: + aprint(f"copying {key} from original_state_dict") + # Memory optimization: avoid unnecessary copies + tensor = original_state_dict[key] + if not tensor.is_contiguous(): + out_state_dict[key] = tensor.contiguous() + else: + out_state_dict[key] = tensor + copy_time = time.time() - copy_start_time + mprint( + f"Phase 3 - Copy remaining keys: {copy_time:.2f}s ({len(keys_to_copy_from_orig_model)} keys)" + ) + + # Handle hidden size pruning for remaining keys + if not is_same_hidden_size: + out_state_dict = _apply_hidden_size_pruning( + out_state_dict, + original_state_dict, + new_config, + original_config, + hidden_size_init_mode, + channel_importance_path, + owned_block_indexes, + ) + + # Phase 4: Verification + verify_start_time = time.time() + _verify_state_dicts_match(out_state_dict, expected_keys_and_shapes) + verify_time = time.time() - verify_start_time + mprint(f"Phase 4 - Verification: {verify_time:.2f}s") + + total_time = time.time() - total_start_time + mprint(f"=== create_child_state_dict completed in {total_time:.2f}s ===") + mprint( + f"Breakdown: Setup {setup_time:.1f}s + ParallelProcessing {layer_processing_time:.1f}s + Copy {copy_time:.1f}s + Verify {verify_time:.1f}s" + ) + mprint( + f"Speedup: Used {max_layer_workers} workers for {len(owned_block_indexes)} layers (CPU utilization: {max_layer_workers}/{os.cpu_count() or 1})" + ) + + return out_state_dict + + +def _generate_moe_keys(layer_idx: int, num_experts: int) -> tuple[str, dict[str, list[str]]]: + mlp_prefix = f"model.layers.{layer_idx}.mlp" + router_key = f"{mlp_prefix}.router.weight" + names = ["gate_proj", "up_proj", "down_proj"] + experts_module_names = { + name: f"{mlp_prefix}.experts.{{expert_idx}}.{name}.weight" for name in names + } + return router_key, { + name: [module_name.format(expert_idx=expert_idx) for expert_idx in range(num_experts)] + for name, module_name in experts_module_names.items() + } + + +def _concatenate_experts_into_dense_ffn( + original_state_dict: dict[str, torch.Tensor], + mlp_init_config: Optional[dict], + hidden_size: int, + layer_idx: int, + child_block_config: BlockConfig, + parent_block_config: BlockConfig, +) -> dict[str, torch.Tensor]: + assert child_block_config.ffn.gated and child_block_config.ffn.hidden_act == "silu", ( + "Llama4 experts use SwiGLU." + ) + + # verify sizes + child_intermediate_size = child_block_config.ffn.intermediate_size + parent_moe_config = parent_block_config.ffn.moe + shared_expert_intermediate_dim = parent_moe_config.shared_expert_intermediate_dim + routed_expert_intermediate_dim = parent_moe_config.expert_intermediate_dim + total_concatenated_routed_experts_size = ( + child_intermediate_size - shared_expert_intermediate_dim + ) + assert total_concatenated_routed_experts_size % routed_expert_intermediate_dim == 0, ( + f"{child_intermediate_size=} " + f"{shared_expert_intermediate_dim=} " + f"{routed_expert_intermediate_dim=} " + f"{total_concatenated_routed_experts_size=} " + f"{total_concatenated_routed_experts_size % routed_expert_intermediate_dim=} != 0" + ) + num_concatenated_routed_experts = ( + total_concatenated_routed_experts_size // routed_expert_intermediate_dim + ) + + # if needed, concatenate some of the routed experts + if num_concatenated_routed_experts == 0: + print( + f"Removing all routed experts from layer {layer_idx}, turning the shared expert into a dense FFN." + ) + concat_routed_state_dict = dict() + else: + print( + f"Concatenating {num_concatenated_routed_experts} routed experts to the shared expert in layer {layer_idx}" + ) + router_key, orig_experts_keys = _generate_moe_keys( + layer_idx, parent_moe_config.num_local_experts + ) + orig_experts_weights = { + name: [original_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + _, experts_weights = _prune_experts_by_score( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_router_weight=original_state_dict[router_key], + orig_experts_weights=orig_experts_weights, + new_num_experts=num_concatenated_routed_experts, + ) + concat_dims = {"gate_proj": 0, "up_proj": 0, "down_proj": 1} + assert list(concat_dims) == list(experts_weights), ( + "concat_dims and experts_weights must have the same keys" + ) + concat_routed_state_dict = { + name: torch.cat(experts_weights[name], dim=concat_dims[name]) + for name in concat_dims.keys() + } + + # turn the shared expert into a normal FFN. concatenate the pruned routed experts if needed. + parent_shared_expert_prefix = f"model.layers.{layer_idx}.mlp.shared_expert" + child_ffn_prefix = f"model.layers.{layer_idx}.mlp" + child_ffn_state_dict = dict() + + for module_name in [ + "gate_proj", + "up_proj", + "down_proj", + ]: + shared_expert_key = f"{parent_shared_expert_prefix}.{module_name}.weight" + child_ffn_key = f"{child_ffn_prefix}.{module_name}.weight" + shared_expert_weight = original_state_dict[shared_expert_key] + concat_routed_weight = concat_routed_state_dict.get(module_name) + + if concat_routed_weight is None: + child_weight = shared_expert_weight + else: + child_weight = torch.cat( + [shared_expert_weight, concat_routed_weight], + dim=1 if module_name == "down_proj" else 0, + ) + child_ffn_state_dict[child_ffn_key] = child_weight + + return child_ffn_state_dict + + +def _verify_state_dicts_match( + state_dict: dict[str, torch.Tensor], + expected_keys_and_shapes: dict[str, torch.Size], +) -> None: + # Verify keys match + expected_keys = expected_keys_and_shapes.keys() + missing_keys = set(expected_keys) - set(state_dict.keys()) + unexpected_keys = set(state_dict.keys()) - set(expected_keys) + assert len(missing_keys) == 0 and len(unexpected_keys) == 0, ( + f"Missing keys: {missing_keys}\nUnexpected keys: {unexpected_keys}" + ) + + # Verify shapes match + shape_mismatches = [] + for key in expected_keys: + expected_shape = expected_keys_and_shapes[key] + actual_shape = state_dict[key].shape + if expected_shape != actual_shape: + shape_mismatches.append(f"{key}: expected {expected_shape}, got {actual_shape}") + + assert len(shape_mismatches) == 0, "Shape mismatches found:\n" + "\n".join(shape_mismatches) + print(""" +############################ +create_child_state_dict: all keys and shapes matched successfully. +############################ +""") + + +def _init_mlp( + *, + mlp_init_mode: Union[MlpInitMode, str], + layer_idx: int, + original_config: DeciLMConfig, + mlp_init_config: Optional[dict[str, Any]], + original_state_dict: dict, + new_state_dict: dict, + new_config: DeciLMConfig, + keys: dict[str, str], + ignored_keys: set[str], + expert_idx: Optional[int] = None, +) -> dict[str, torch.Tensor]: + out_state_dict = {} + + if mlp_init_mode == MlpInitMode.MoEChannelPruning: + if expert_idx is None: + return {} + mlp_prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}" + else: + mlp_prefix = f"model.layers.{layer_idx}.mlp" + + key = f"{mlp_prefix}.down_proj.weight" + if key not in keys: + return {} + + mlp_c_proj_key = keys[key] + if mlp_c_proj_key not in ignored_keys: + mlp_keys = [ + keys.pop(f"{mlp_prefix}.{module_name}.weight") + for module_name in ["down_proj", "gate_proj", "up_proj"] + ] + pruned_filters = None + projection_matrix = None + for mlp_key in mlp_keys: + expanded_dim = 1 if "down_proj" in mlp_key else 0 + if mlp_key in new_state_dict.keys(): + mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( + mlp_init_mode, + expanded_dim, + new_state_dict[mlp_key], + new_config, + original_state_dict[mlp_key], + original_config, + mlp_init_config, + pruned_filters, + projection_matrix, + mlp_prefix, + ) + out_state_dict[mlp_key] = mlp_module_weight + else: + mprint(f"mlp_key {mlp_key} not in new_state_dict") + return out_state_dict + + +def _init_mlp_module( + mlp_init_mode: Union[MlpInitMode, str], + expanded_dim: int, + new_item: torch.Tensor, + new_config: DeciLMConfig, + orig_item: torch.Tensor, + original_config: DeciLMConfig, + mlp_init_config: Optional[dict[str, Any]], + pruned_filters: Optional[torch.Tensor] = None, + projection_matrix: Optional[dict[str, torch.Tensor]] = None, + mlp_prefix: Optional[str] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + assert orig_item.ndim == 2, f"{orig_item.ndim=}" + assert new_item.ndim == 2, f"{new_item.ndim=}" + + assert new_config.num_hidden_layers == original_config.num_hidden_layers, ( + f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})" + ) + + orig_ffn_size = orig_item.shape[expanded_dim] + new_ffn_size = new_item.shape[expanded_dim] + + if mlp_init_mode == MlpInitMode.CopyAsIs: + assert new_ffn_size == orig_ffn_size, ( + f"({new_ffn_size=}) != ({orig_ffn_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + elif mlp_init_mode == MlpInitMode.Random: + mlp_module_weight = new_item + + elif new_ffn_size == orig_ffn_size: + mlp_module_weight = orig_item + + elif mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + MlpInitMode.MoEChannelPruning, + ): + assert new_ffn_size <= orig_ffn_size, ( + f"({new_ffn_size=}) > ({orig_ffn_size=}), can't be truncated." + ) + + if mlp_init_mode == MlpInitMode.Truncate: + truncated_weight = torch.narrow( + orig_item, dim=expanded_dim, start=0, length=new_ffn_size + ) + mlp_module_weight = truncated_weight + + elif mlp_init_mode in (MlpInitMode.PruneByActivationsLog, MlpInitMode.MoEChannelPruning): + if pruned_filters is None: + filter_importance = _load_activations_log( + mlp_init_config, module_name=f"{mlp_prefix}.down_proj" + ) + filters_sorted_by_importance = torch.argsort(filter_importance, descending=True) + pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device) + + pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters) + if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1: + pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size) + mlp_module_weight = pruned_weight + + elif ( + mlp_init_mode == MlpInitMode.ExpertRemoval + ): # the case of mlp layers of maverick. for now we only support copy as is + assert new_ffn_size == orig_ffn_size, ( + f"({new_ffn_size=}) != ({orig_ffn_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + else: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + return mlp_module_weight, pruned_filters, projection_matrix + + +def _init_moe_module( + *, + mlp_init_mode: Union[MlpInitMode, str], + mlp_init_config: Optional[dict[str, Any]], + layer_idx: int, + orig_router_weight: torch.Tensor, + orig_experts_weights: dict[str, list[torch.Tensor]], + new_router_weight: torch.Tensor, + new_experts_weights: dict[str, list[torch.Tensor]], +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + + if mlp_init_mode == MlpInitMode.ExpertRemoval: + result_router_weight, result_experts_weights = _prune_experts_by_score( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_router_weight=orig_router_weight, + orig_experts_weights=orig_experts_weights, + new_num_experts=new_router_weight.shape[0], + ) + else: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + assert result_router_weight.shape == new_router_weight.shape + assert result_experts_weights.keys() == new_experts_weights.keys(), ( + "result_experts_weights and new_experts_weights must have the same keys" + ) + assert all( + len(new_experts_weights[name]) == len(result_experts_weights[name]) + for name in result_experts_weights.keys() + ) + assert all( + all( + new_expert_weight.shape == result_expert_weight.shape + for new_expert_weight, result_expert_weight in zip( + new_experts_weights[name], result_experts_weights[name] + ) + ) + for name in result_experts_weights.keys() + ) + return result_router_weight, result_experts_weights + + +def _prune_experts_by_score( + *, + mlp_init_config: dict[str, Any], + layer_idx: int, + orig_router_weight: torch.Tensor, + orig_experts_weights: dict[str, list[torch.Tensor]], + new_num_experts: int, +) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]: + orig_num_experts = orig_router_weight.shape[0] + assert all( + len(orig_experts_module_weights) == orig_num_experts + for orig_experts_module_weights in orig_experts_weights.values() + ) + expert_scores = _load_expert_scores(mlp_init_config)[layer_idx] + assert len(expert_scores) == orig_num_experts + selected_experts = sorted( + range(orig_num_experts), + key=lambda i: expert_scores[i], + reverse=mlp_init_config.get("higher_is_better", True), + )[:new_num_experts] + result_router_weight = orig_router_weight[selected_experts] + result_experts_weights = { + name: [orig_experts_module_weights[i] for i in selected_experts] + for name, orig_experts_module_weights in orig_experts_weights.items() + } + return result_router_weight, result_experts_weights + + +def _load_expert_scores(mlp_init_config: Optional[dict[str, Any]]) -> list[list[int | float]]: + assert mlp_init_config is not None + if "expert_scores_file" in mlp_init_config: + expert_scores_file = mlp_init_config["expert_scores_file"] + with open(expert_scores_file, "r") as f: + expert_scores = json.load(f) + elif "activations_log_dir" in mlp_init_config: + _cache_activations_log(mlp_init_config) + num_layers = len(ACTIVATIONS_LOG) + expert_scores = [] + for layer_idx in range(num_layers): + router_name = f"model.layers.{layer_idx}.mlp.router" + expert_scores.append(ACTIVATIONS_LOG[router_name]["expert_ranks"]) + expert_scores = torch.stack(expert_scores) + expert_scores = expert_scores.tolist() + else: + raise ValueError(f"Unsupported {mlp_init_config=}") + return expert_scores + + +ACTIVATIONS_LOG = dict() + + +def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: + if len(ACTIVATIONS_LOG) == 0: + assert "activations_log_dir" in mlp_init_config + activations_log_dir = mlp_init_config["activations_log_dir"] + print(f"Loading activations_log from {activations_log_dir}") + ACTIVATIONS_LOG.update( + { + module_name: module_log + for p in Path(activations_log_dir).glob("rank*.pth") + for module_name, module_log in torch.load(p).items() + } + ) + + +def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor: + _cache_activations_log(mlp_init_config) + module_log = ACTIVATIONS_LOG[module_name] + filter_importance = module_log["score"] + return filter_importance + + +def _init_attention_weights( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group + orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group + num_kv_heads = num_q_heads // n_heads_in_group + orig_num_kv_heads = num_q_heads // orig_n_heads_in_group + + # new_w* are typically randomly initialized + new_wq = new_state_dict[q_key] + new_wk = new_state_dict[k_key] + new_wv = new_state_dict[v_key] + new_wo = new_state_dict[o_key] + + # w* are from the parent model + wq = original_state_dict[q_key] + wk = original_state_dict[k_key] + wv = original_state_dict[v_key] + wo = original_state_dict[o_key] + + if "bias" in k_key: + for tensor in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]: + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases + + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock): + wk, wv = new_wk, new_wv + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV): + assert n_heads_in_group % orig_n_heads_in_group == 0, ( + f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" + ) + n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group + + wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1) + wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + wk = wk.mean(dim=1) + wv = wv.mean(dim=1) + else: + wk = wk[:, 0] + wv = wv[:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})" + assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})" + assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})" + assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})" + + elif gqa_init_mode == GQAInitMode.Degrouping: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = new_config.num_attention_heads // n_heads_in_group + orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + wk = degroup_w(wk) + wv = degroup_w(wv) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + wk = wk.view(orig_num_kv_heads, head_size, dim1) + wv = wv.view(orig_num_kv_heads, head_size, dim1) + wq = wq.view(orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1) + wo = wo.view(dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size) + + o_proj_module_name = o_key.replace(".weight", "") + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + wk = wk[kv_heads_to_keep] + wv = wv[kv_heads_to_keep] + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + wq = wq[kv_heads_to_keep] + wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0) + + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + wo = wo[:, kv_heads_to_keep] + wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1) + wo = wo / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + wq = wq[kv_head_ordering] + + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + wo = wo[:, kv_head_ordering] + wo[:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + wk = wk.reshape(-1, dim1) + wv = wv.reshape(-1, dim1) + wq = wq.reshape(-1, dim1) + wo = wo.reshape(dim1, -1) + return wq, wk, wv, wo + + +def _init_attention_biases( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config: DeciLMConfig, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group + orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group + num_kv_heads = num_q_heads // n_heads_in_group + orig_num_kv_heads = num_q_heads // orig_n_heads_in_group + + o_proj_bias = new_config.o_proj_bias + attention_bias = new_config.attention_bias + + # If no biases + if not (o_proj_bias or attention_bias): + return {} + + new_bias_sd = {} + bias_sd = {} + # new_w* are typically randomly initialized + if o_proj_bias: + new_bias_sd["o"] = new_state_dict[o_key] + bias_sd["o"] = original_state_dict[o_key] + if attention_bias: + for bias_key, key in zip("qkv", [q_key, k_key, v_key]): + new_bias_sd[bias_key] = new_state_dict[key] + bias_sd[bias_key] = original_state_dict[key] + + # maybe unsqueeze all tensors + for tensor in list(new_bias_sd.values()) + list(bias_sd.values()): + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + + dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias: + bias_sd["k"] = torch.zeros( + new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device + ) + bias_sd["v"] = torch.zeros( + new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device + ) + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias: + assert n_heads_in_group % orig_n_heads_in_group == 0, ( + f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" + ) + n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group + + bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + bias_sd["k"] = bias_sd["k"].mean(dim=1) + bias_sd["v"] = bias_sd["v"].mean(dim=1) + else: + bias_sd["k"] = bias_sd["k"][:, 0] + bias_sd["v"] = bias_sd["v"][:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + for key in bias_sd.keys(): + assert new_bias_sd[key].shape == bias_sd[key].shape, ( + f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" + ) + + elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = new_config.num_attention_heads // n_heads_in_group + orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + bias_sd["k"] = degroup_w(bias_sd["k"]) + bias_sd["v"] = degroup_w(bias_sd["v"]) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + if o_proj_bias: + o_proj_module_name = o_key.rsplit(".", 1)[0] + else: + # Here we assume that the o_proj layer is called "o_proj" + o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj" + + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + # view as KV groups + if attention_bias: + bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["q"] = bias_sd["q"].view( + orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1 + ) + # Keep important KV heads and prune the others + bias_sd["k"] = bias_sd["k"][kv_heads_to_keep] + bias_sd["v"] = bias_sd["v"][kv_heads_to_keep] + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].view( + dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size + ) + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + if attention_bias: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + bias_sd["q"] = bias_sd["q"][kv_heads_to_keep] + bias_sd["q"] = torch.repeat_interleave( + bias_sd["q"], repeats=reduction_factor, dim=0 + ) + + if o_proj_bias: + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep] + bias_sd["o"] = torch.repeat_interleave( + bias_sd["o"], repeats=reduction_factor, dim=1 + ) + bias_sd["o"] = bias_sd["o"] / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + if attention_bias: + bias_sd["q"] = bias_sd["q"][kv_head_ordering] + + if o_proj_bias: + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + bias_sd["o"] = bias_sd["o"][:, kv_head_ordering] + bias_sd["o"][:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + if attention_bias: + for bias_key in "qkv": + bias_sd[bias_key] = bias_sd[bias_key].reshape(-1) + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].reshape(-1) + return bias_sd + + +def _init_linear_attn( + parent_state_dict: dict[str, torch.Tensor], + parent_config: DeciLMConfig, + layer_idx: int, + v_key: str, + o_key: str, +) -> torch.Tensor: + """ + Init a linear layer that operates like an attention layer that assigns score 1 to the current token + and score 0 to all others: out = (Wo @ Wv) @ x + """ + n_embd = parent_config.hidden_size + head_size = parent_config.head_dim + n_heads_in_group = parent_config.block_configs[layer_idx].attention.n_heads_in_group + n_kv_heads = parent_config.num_attention_heads // n_heads_in_group + + wv = parent_state_dict[v_key] + wv = wv.view(n_kv_heads, head_size, n_embd) + wv_expanded = torch.repeat_interleave(wv, n_heads_in_group, dim=0).reshape(n_embd, n_embd) + + wo = parent_state_dict[o_key] + + w_linear = wo @ wv_expanded + return w_linear + + +def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.Tensor: + """ + A linear layer that does (W_down @ W_up) @ x, ignoring W_gate. + """ + if "linear_mlp.weight" in teacher_mlp_state_dict: # if the teacher itself is a linear layer + return teacher_mlp_state_dict["linear_mlp.weight"] + + w_up = teacher_mlp_state_dict["up_proj.weight"] + w_down = teacher_mlp_state_dict["down_proj.weight"] + w_linear = w_down @ w_up + return w_linear + + +def update_model_config( + model_config: DeciLMConfig, + model_config_overrides: None | list[dict[str, Any]] | str | dict | Path = None, +) -> DeciLMConfig: + new_model_config = deepcopy(model_config) + if model_config_overrides is None: + return new_model_config + + model_config_overrides = _parse_model_config_overrides( + model_config_overrides, model_config.num_hidden_layers + ) + + def override(item, item_overrides): + if item_overrides is None: + return item_overrides + if dataclasses.is_dataclass(item): + assert isinstance(item_overrides, dict) + return dataclass_override(item, item_overrides) + if isinstance(item, list): + assert isinstance(item_overrides, list) + return list_override(item, item_overrides) + return item_overrides + + def list_override(ls, ls_overrides: list): + assert len(ls) == len(ls_overrides) + return [override(item, item_overrides) for item, item_overrides in zip(ls, ls_overrides)] + + def dataclass_override(dc, dc_overrides: dict): + if not set(dc_overrides.keys()).issubset(dataclasses.asdict(dc).keys()): + raise ValueError( + f"Uknown overrides for dataclass {type(dc)}: {', '.join(set(dc_overrides.keys()) - dataclasses.asdict(dc).keys())}" + ) + field_types = {field.name: field.type for field in dataclasses.fields(dc)} + dc_changes = {} + for key, item_overrides in dc_overrides.items(): + previous_value, item_type = getattr(dc, key), field_types[key] + # if original block was no_op, we should not override it + if getattr(dc, "no_op", False): + return dc + + if previous_value is None and _is_dataclass_type(item_type): + new_value = _get_dataclass_type(item_type)(**item_overrides) + else: + new_value = override(previous_value, item_overrides) + check_type(new_value, item_type) + dc_changes[key] = new_value + return dataclasses.replace(dc, **dc_changes) + + new_model_config.block_configs = list_override( + new_model_config.block_configs, model_config_overrides + ) + + return new_model_config + + +def _parse_model_config_overrides( + model_config_overrides_json: str | dict | Path | list[dict], + n_layer: int, +) -> list[dict[str, Any]]: + """ + example model_config_overrides_json: + { + "attention": [{"n_heads_in_group": 2}], + "ffn": [{"intermediate_size": 14336}] + } + """ + if isinstance(model_config_overrides_json, list) and isinstance( + model_config_overrides_json[0], dict + ): + return model_config_overrides_json + + if isinstance(model_config_overrides_json, dict): + model_config_overrides_dict = model_config_overrides_json + else: + if os.path.exists( + model_config_overrides_json + ): # using os.path.exists, because Path.exists throws an exception on long strings + model_config_overrides_json = Path(model_config_overrides_json).read_text() + print(f"I'm json loadsing over here. {model_config_overrides_json=}") + model_config_overrides_dict = json.loads(model_config_overrides_json) + + # Sanity checks and conversion to list of dictionaries + layer_wise_overrides = [{} for _ in range(n_layer)] + for config_key, config_value in model_config_overrides_dict.items(): + assert config_key in SUBBLOCK_CLS_DICT, f"Unknown config key: {config_key}" + assert isinstance(config_value, list), ( + f"Expected a list for {config_key}, got {config_value}" + ) + assert len(config_value) == n_layer or len(config_value) == 1, ( + f"Number of elements in {config_key} must be 1 or equal to the number of layers in the model" + ) + + if len(config_value) == 1: + model_config_overrides_dict[config_key] = config_value * n_layer + + for layer_idx in range(n_layer): + layer_wise_overrides[layer_idx][config_key] = model_config_overrides_dict[config_key][ + layer_idx + ] + + return layer_wise_overrides + + +def _apply_hidden_size_pruning( + out_state_dict: dict[str, torch.Tensor], + original_state_dict: dict[str, torch.Tensor], + new_config: DeciLMConfig, + original_config: DeciLMConfig, + hidden_size_init_mode: HiddenSizeInitMode, + channel_importance_path: Optional[str] = None, + owned_block_indexes: Optional[list[int]] = None, +) -> dict[str, torch.Tensor]: + """ + Apply hidden size pruning to all layers that depend on hidden_size. + This includes embeddings, layer norms, and any linear layers that haven't been handled yet. + """ + if isinstance(hidden_size_init_mode, str): + hidden_size_init_mode = HiddenSizeInitMode(hidden_size_init_mode) + + original_hidden_size = original_config.hidden_size + new_hidden_size = new_config.hidden_size + + if hidden_size_init_mode == HiddenSizeInitMode.CopyAsIs: + return out_state_dict + + # Load channel ranking if needed + if hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: + if channel_importance_path is not None: + with open(channel_importance_path, "r") as f: + channel_ranking = json.load(f)["channel_importance_ranking"] + else: + raise ValueError( + "channel_ranking_path must be provided in hidden_size_init_config for PruneByChannelRanking mode" + ) + + # Handle embedding layer + embed_key = "model.embed_tokens.weight" + if embed_key in out_state_dict and embed_key in original_state_dict: + out_state_dict[embed_key] = _prune_hidden_size_dimension( + original_state_dict[embed_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + else: + raise ValueError( + f"Embed key {embed_key} not found in out_state_dict or original_state_dict" + ) + + # Handle final layer norm + norm_key = "model.norm.weight" + if norm_key in out_state_dict and norm_key in original_state_dict: + out_state_dict[norm_key] = _prune_hidden_size_dimension( + original_state_dict[norm_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=0, + ) + + # Handle LM head + lm_head_key = "lm_head.weight" + if lm_head_key in out_state_dict and lm_head_key in original_state_dict: + if out_state_dict[lm_head_key].shape[1] != new_hidden_size: + out_state_dict[lm_head_key] = _prune_hidden_size_dimension( + original_state_dict[lm_head_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + + for block_idx in owned_block_indexes: + if new_config.block_configs[block_idx].parallel_blocks is None: + key_prefix = f"model.layers.{block_idx}" + out_state_dict = _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + new_config.block_configs[block_idx], + key_prefix, + ) + else: + for internal_block_idx in range( + len(new_config.block_configs[block_idx].parallel_blocks) + ): + block_config = new_config.block_configs[block_idx].parallel_blocks[ + internal_block_idx + ] + key_prefix = f"model.layers.{block_idx}.parallel_blocks.{internal_block_idx}" + out_state_dict = _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + block_config, + key_prefix, + ) + return out_state_dict + + +def _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + block_config, + key_prefix, +): + for layer_norm in ["input_layernorm", "post_attention_layernorm"]: + for part in ["weight", "bias"]: + key = f"{key_prefix}.{layer_norm}.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=0, + ) + attn_prefix = f"{key_prefix}.self_attn" + if block_config.attention.replace_with_linear: + linear_attn_key = f"{attn_prefix}.linear_attn.weight" + for dim in [0, 1]: + out_state_dict[linear_attn_key] = _prune_hidden_size_dimension( + out_state_dict[linear_attn_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=dim, + ) + elif block_config.attention.is_mamba: + for proj in ["in", "out"]: + mamba_key = f"{attn_prefix}.mamba_mixer.{proj}_proj.weight" + out_state_dict[mamba_key] = _prune_hidden_size_dimension( + out_state_dict[mamba_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if proj == "in" else 0, + ) + else: + for k in "qkvo": + for part in ["weight", "bias"]: + if k in "qkv" and part == "bias": + continue + key = f"{attn_prefix}.{k}_proj.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if part == "weight" and k in "qkv" else 0, + ) + ffn_prefix = f"{key_prefix}.mlp" + if block_config.ffn.replace_with_linear: + linear_mlp_key = f"{ffn_prefix}.linear_mlp.weight" + for dim in [0, 1]: + out_state_dict[linear_mlp_key] = _prune_hidden_size_dimension( + out_state_dict[linear_mlp_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=dim, + ) + elif block_config.ffn.moe is not None: + router_key = f"{ffn_prefix}.router.weight" + out_state_dict[router_key] = _prune_hidden_size_dimension( + out_state_dict[router_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + _prune_hidden_size_dimension_mlp( + f"{ffn_prefix}.shared_expert", + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + ) + for expert_idx in range(block_config.ffn.moe.num_local_experts): + _prune_hidden_size_dimension_mlp( + f"{ffn_prefix}.experts.{expert_idx}", + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + ) + else: + _prune_hidden_size_dimension_mlp( + ffn_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking + ) + return out_state_dict + + +def _prune_hidden_size_dimension_mlp( + name_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking +): + for proj in ["gate_proj", "up_proj", "down_proj"]: + for part in ["weight", "bias"]: + if proj != "down_proj" and part == "bias": + continue + key = f"{name_prefix}.{proj}.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if part == "weight" and proj != "down_proj" else 0, + ) + + +def _prune_hidden_size_dimension( + original_tensor: torch.Tensor, + new_hidden_size: int, + hidden_size_init_mode: HiddenSizeInitMode, + channel_ranking: Optional[list[int]] = None, + dim: int = -1, +) -> torch.Tensor: + """ + Prune a tensor along the specified dimension to match the new hidden size. + """ + original_size = original_tensor.shape[dim] + + if hidden_size_init_mode == HiddenSizeInitMode.Random: + # Initialize with random weights + new_shape = list(original_tensor.shape) + new_shape[dim] = new_hidden_size + return torch.randn(new_shape, dtype=original_tensor.dtype, device=original_tensor.device) + + elif hidden_size_init_mode == HiddenSizeInitMode.Truncate: + # Simple truncation - take the first new_hidden_size elements + if dim == -1: + return original_tensor[..., :new_hidden_size] + elif dim == 0: + return original_tensor[:new_hidden_size, ...] + elif dim == 1: + return original_tensor[:, :new_hidden_size, ...] + else: + # Handle other dimensions + slices = [slice(None)] * original_tensor.ndim + slices[dim] = slice(new_hidden_size) + return original_tensor[tuple(slices)] + + elif hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: + if channel_ranking is None: + raise ValueError("Channel ranking must be provided for PruneByChannelRanking mode") + + # Use channel ranking to select the most important channels + if len(channel_ranking) < new_hidden_size: + raise ValueError( + f"Channel ranking has {len(channel_ranking)} channels but need {new_hidden_size}" + ) + + # Take the top new_hidden_size channels according to ranking + selected_channels = channel_ranking[:new_hidden_size] + + if dim == -1: + return original_tensor[..., selected_channels] + elif dim == 0: + return original_tensor[selected_channels, ...] + elif dim == 1: + return original_tensor[:, selected_channels, ...] + else: + # Handle other dimensions + slices = [slice(None)] * original_tensor.ndim + slices[dim] = selected_channels + return original_tensor[tuple(slices)] + + else: + raise ValueError(f"Unsupported hidden_size_init_mode: {hidden_size_init_mode}") diff --git a/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py new file mode 100644 index 0000000000..dbb4eac0c8 --- /dev/null +++ b/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""TODO Add description""" + +import argparse +import json +import time +from typing import Optional + +import torch +import yaml + +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch._compress.tools.bypassed_training.child_init import ( + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + create_child_state_dict, + update_model_config, +) +from modelopt.torch._compress.tools.checkpoint_utils import ( + copy_tokenizer, + load_model_config, + load_state_dict, +) +from modelopt.torch._compress.tools.checkpoint_utils_hf import ( + _save_checkpoint, + copy_deci_lm_hf_code, +) +from modelopt.torch._compress.tools.logger import mprint + +""" + +Usage example - remove all/some routed experts: +=============================================== + +PARENT_DIR=".../meta-llama/Llama-4-Scout-17B-16E-Instruct--deci-hf" + +MLP_INIT_MODE="ConcatExpertsIntoDenseFFN" + +## remove all routed experts, turn the shared expert into a dense FFN +# OUTPUT_DIR="/.../micro_scout/Scout-remove-routed-experts" +# MODEL_CONFIG_OVERRIDES_JSON=' +# { +# "ffn": [ +# { +# "moe": null, +# "intermediate_size": 14336, +# "gated": true, +# "hidden_act": "silu" +# } +# ] +# } +# ' + +## concat the shared expert with one routed expert into a dense FFN +OUTPUT_DIR=".../scratch/micro_scout/Scout-ConcatExpertsIntoDenseFFN-concat-shared-and-3-routed" +MODEL_CONFIG_OVERRIDES_JSON=' +{ + "ffn": [ + { + "moe": null, + "intermediate_size": 14336, + "gated": true, + "hidden_act": "silu" + } + ] +} +' + +echo "" +echo "MODEL_CONFIG_OVERRIDES_JSON:" +echo "${MODEL_CONFIG_OVERRIDES_JSON}" + +python -m modelopt.torch._compress.tools.bypassed_training.init_child_from_parent \ + --parent_checkpoint_dir="$PARENT_DIR" \ + --model_config_overrides_json="$MODEL_CONFIG_OVERRIDES_JSON" \ + --output_checkpoint_dir="$OUTPUT_DIR" \ + --mlp_init_mode="$MLP_INIT_MODE" \ + --mlp_init_config_yaml="$MLP_INIT_CONFIG_YAML" +""" + + +def init_child_from_parent( + parent_checkpoint_dir: str, + model_config_overrides_json: str, + output_checkpoint_dir: str, + gqa_init_mode: GQAInitMode, + mlp_init_mode: MlpInitMode, + mlp_init_config_yaml: Optional[str], + linear_init_mode: LinearInitMode, + hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, + channel_importance_path: Optional[str] = None, + max_workers: Optional[int] = None, # Auto-calculate optimal workers if None + max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None +) -> None: + """ + Init child models from parent models in the style of bypass training, + but without having to run the entire bypass pipeline. + + I/O Optimization Parameters: + - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files)) + - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) + """ + assert ( + gqa_init_mode != GQAInitMode.RandomKV + and gqa_init_mode != GQAInitMode.RandomBlock + and mlp_init_mode != MlpInitMode.Random + and linear_init_mode != LinearInitMode.Random + ), ( + "We do not support random init of any subblock in this script to avoid initializing the student model" + ) + + copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir) + + parent_model_config = load_model_config(parent_checkpoint_dir) + parent_state_dict = load_state_dict(parent_checkpoint_dir) + + # Parse the model config overrides + if isinstance(model_config_overrides_json, str): + model_config_overrides_dict = json.loads(model_config_overrides_json) + else: + model_config_overrides_dict = model_config_overrides_json + + # Separate global config overrides from block-level overrides + global_config_overrides = {} + block_config_overrides = {} + + for key, value in model_config_overrides_dict.items(): + if key in ["hidden_size"]: + global_config_overrides[key] = value + else: + block_config_overrides[key] = value + + # Load child model config with global overrides + child_model_config = load_model_config( + checkpoint_dir=parent_checkpoint_dir, + model_config_overrides=global_config_overrides, + ignore_unexpected_config_keys=True, + ) + + # Apply block-level overrides if any + if block_config_overrides: + child_model_config = update_model_config( + model_config=child_model_config, + model_config_overrides=block_config_overrides, + ) + + with torch.device("meta"): + child_model = DeciLMForCausalLM(child_model_config) + child_state_dict_with_meta_tensors = child_model.state_dict() + + mlp_init_config = ( + yaml.safe_load(mlp_init_config_yaml) + if isinstance(mlp_init_config_yaml, str) is None + else mlp_init_config_yaml + ) + + # Profile create_child_state_dict with automatic layer parallelization + mprint("Starting create_child_state_dict...") + start_time = time.time() + child_state_dict = create_child_state_dict( + original_state_dict=parent_state_dict, + new_state_dict=child_state_dict_with_meta_tensors, + original_config=parent_model_config, + new_config=child_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs, + channel_importance_path=channel_importance_path, + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + create_child_state_dict_time = time.time() - start_time + mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds") + + # Profile _save_checkpoint with automatic I/O worker calculation + mprint("Starting _save_checkpoint...") + actual_io_workers = max_workers if max_workers else "auto" + mprint(f"I/O Settings: max_workers={actual_io_workers}") + start_time = time.time() + _save_checkpoint( + child_model_config, + child_state_dict, + output_checkpoint_dir, + max_workers=max_workers, # Will auto-calculate if None + ) + save_checkpoint_time = time.time() - start_time + mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") + + copy_deci_lm_hf_code(output_checkpoint_dir) + + # Print profiling summary with actual worker counts used + total_core_time = create_child_state_dict_time + save_checkpoint_time + actual_layer_workers = max_layer_workers if max_layer_workers else "auto" + actual_io_workers = max_workers if max_workers else "auto" + mprint(f"\n=== PROFILING SUMMARY ===") + mprint( + f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)" + ) + mprint( + f"_save_checkpoint: {save_checkpoint_time:.2f}s ({save_checkpoint_time / total_core_time * 100:.1f}%)" + ) + mprint(f"Total core processing: {total_core_time:.2f}s") + mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") + mprint(f"=========================\n") + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Arguments for single checkpoint creation + parser.add_argument("--parent_checkpoint_dir", type=str, required=True) + parser.add_argument("--model_config_overrides_json", type=str, required=True) + parser.add_argument("--output_checkpoint_dir", type=str, required=True) + parser.add_argument( + "--gqa_init_mode", type=str, default="AverageKV", choices=GQAInitMode._member_names_ + ) + parser.add_argument( + "--mlp_init_mode", type=str, default="Truncate", choices=MlpInitMode._member_names_ + ) + parser.add_argument("--mlp_init_config_yaml", type=str, default=None) + parser.add_argument( + "--linear_init_mode", type=str, default="FromTeacher", choices=LinearInitMode._member_names_ + ) + parser.add_argument( + "--hidden_size_init_mode", type=str, default=None, choices=HiddenSizeInitMode._member_names_ + ) + parser.add_argument("--channel_importance_path", type=str, required=False) + parser.add_argument("--target_hidden_sizes", type=int, nargs="+", required=False) + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + + init_child_from_parent( + parent_checkpoint_dir=args.parent_checkpoint_dir, + model_config_overrides_json=args.model_config_overrides_json, + output_checkpoint_dir=args.output_checkpoint_dir, + gqa_init_mode=GQAInitMode(args.gqa_init_mode), + mlp_init_mode=MlpInitMode(args.mlp_init_mode), + mlp_init_config_yaml=args.mlp_init_config_yaml, + linear_init_mode=LinearInitMode(args.linear_init_mode), + hidden_size_init_mode=HiddenSizeInitMode(args.hidden_size_init_mode) + if args.hidden_size_init_mode + else None, + ) diff --git a/modelopt/torch/_compress/tools/kd_model.py b/modelopt/torch/_compress/tools/kd_model.py index 437eb51ca2..8590c3f56c 100644 --- a/modelopt/torch/_compress/tools/kd_model.py +++ b/modelopt/torch/_compress/tools/kd_model.py @@ -22,11 +22,11 @@ # mypy: ignore-errors from abc import ABCMeta, abstractmethod -from typing import List, Callable, Literal, Tuple, Optional +from typing import Callable, List, Literal, Optional, Tuple import torch import torch.nn.functional as F -from torch import nn, Tensor +from torch import Tensor, nn def normalized_mse_loss( diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 549ee9a88c..8d1a222c89 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -29,7 +29,6 @@ import torch.distributed import torch.nn as nn from huggingface_hub import split_torch_state_dict_into_shards -from modelopt.torch._compress.tools.logger import mprint from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index 0e745a0646..47e8e4202d 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -32,13 +32,6 @@ PreTrainedModel, PreTrainedTokenizerBase, ) -from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader -from modelopt.torch._compress.utils.parsing import simple_parse_args_string -from modelopt.torch._compress.utils.validate_runtime_pipeline import ( - HiddenStatesAndLMHead, - calculate_losses_pipeline, -) -from modelopt.torch._compress.utils.validation import calculate_losses from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( register_activation_hooks, @@ -47,6 +40,13 @@ from modelopt.torch._compress.tools.logger import aprint, mprint from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader +from modelopt.torch._compress.utils.parsing import simple_parse_args_string +from modelopt.torch._compress.utils.validate_runtime_pipeline import ( + HiddenStatesAndLMHead, + calculate_losses_pipeline, +) +from modelopt.torch._compress.utils.validation import calculate_losses # #TODO:Import slack from root utils directory # root_path = os.path.join(os.path.dirname(__file__), "..", "..") diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/_compress/utils/checkpoint_manager.py index 318586ba44..b96fd21a56 100644 --- a/modelopt/torch/_compress/utils/checkpoint_manager.py +++ b/modelopt/torch/_compress/utils/checkpoint_manager.py @@ -20,8 +20,9 @@ import json import time from pathlib import Path -from typing import Dict, Any, Optional -from modelopt.torch._compress.tools.logger import mprint, aprint +from typing import Any, Dict, Optional + +from modelopt.torch._compress.tools.logger import aprint, mprint class ScoringCheckpointManager: diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/_compress/utils/data/dataloaders.py index 4c4fce0606..584e32480b 100644 --- a/modelopt/torch/_compress/utils/data/dataloaders.py +++ b/modelopt/torch/_compress/utils/data/dataloaders.py @@ -26,11 +26,12 @@ import torch import torch.distributed from accelerate import Accelerator -from modelopt.torch._compress.tools.logger import mprint from torch.utils.data import DataLoader, Dataset, IterableDataset from torch.utils.data._utils.collate import collate, default_collate_fn_map from tqdm import tqdm from transformers import PreTrainedTokenizerBase + +from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.utils.data.dataset import ConstantLengthDataset diff --git a/modelopt/torch/_compress/utils/data/dataset.py b/modelopt/torch/_compress/utils/data/dataset.py index 2c7fcef09a..342b0821ef 100644 --- a/modelopt/torch/_compress/utils/data/dataset.py +++ b/modelopt/torch/_compress/utils/data/dataset.py @@ -14,8 +14,7 @@ # limitations under the License. # mypy: ignore-errors import functools -from typing import Optional -from typing import Sequence +from typing import Optional, Sequence import numpy as np import torch diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py index 08e1221a72..aa8a4f304b 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -29,14 +29,14 @@ import torch import torch.distributed import wandb -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.checkpoint_utils import init_module_with_state_dict +from torch.utils.data import DataLoader +from tqdm import tqdm + from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMForCausalLM, LMHead, ) -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.sewing_kit import ( ExternalTarget, InputArgs, @@ -51,8 +51,9 @@ distributed_send_obj, fake_tensor, ) -from torch.utils.data import DataLoader -from tqdm import tqdm +from modelopt.torch._compress.tools.checkpoint_utils import init_module_with_state_dict +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.tools.sharded_checkpoint_utils import DummyBlock from modelopt.torch._compress.utils.validation import _organize_outputs, calculate_batch_outputs diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py index 63c6642248..662ae4a2b6 100644 --- a/modelopt/torch/_compress/utils/validation.py +++ b/modelopt/torch/_compress/utils/validation.py @@ -32,12 +32,13 @@ import torch.nn.functional as F import wandb from accelerate import Accelerator -from modelopt.torch._compress.tools import kd_model from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper from typing_extensions import Self + +from modelopt.torch._compress.tools import kd_model from modelopt.torch._compress.utils.data.dataloaders import create_padded_tensor diff --git a/pyproject.toml b/pyproject.toml index 7694cbaf3e..051891bcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ extend-ignore = [ "*/_[a-zA-Z]*" = ["D"] # Private packages (_abc/*.py) or modules (_xyz.py) "*.ipynb" = ["D", "E501"] # Ignore missing docstrings or line length for Jupyter notebooks "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style -"modelopt/torch/_compress/*" = ["C4", "D", "E", "F", "FURB", "I", "ISC", "N", "PERF", "PGH", "PIE", "PLE", "PLR", "PT", "RUF", "SIM", "TC", "UP", "W"] # TODO:Disabled for now, will enable later, once all puzzletron code is migrated +"modelopt/torch/_compress/*" = ["C4", "D", "E", "F", "FURB", "ISC", "N", "PERF", "PGH", "PIE", "PLE", "PLR", "PT", "RUF", "SIM", "TC", "UP", "W"] # TODO:Disabled for now, will enable later, once all puzzletron code is migrated [tool.ruff.lint.pycodestyle] diff --git a/setup.py b/setup.py index 18e22feba2..ab70cdf68a 100644 --- a/setup.py +++ b/setup.py @@ -107,6 +107,7 @@ "omegaconf==2.3.0", "wandb~=0.17.5", "lru-dict", + "typeguard", ], } From 97fe7f0a0cc90e6f8ac275e0225b850d6c503d80 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 1 Dec 2025 13:47:44 +0100 Subject: [PATCH 18/62] Add build replacement library to the compress algorithm. (#616) ## What does this PR do? Add build replacement library to the compress algorithm. --------- Signed-off-by: Daniel Korzekwa --- .../_compress/build_library_and_stats.py | 109 ++++ .../nas/plugins/compress_nas_plugin.py | 3 +- .../build_replacement_library.py | 605 ++++++++++++++++++ .../replacement_library.py | 388 +++++++++++ .../replacement_library/replacement_utils.py | 122 ++++ modelopt/torch/_compress/utils/utils.py | 98 +++ setup.py | 2 + 7 files changed, 1326 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/build_library_and_stats.py create mode 100644 modelopt/torch/_compress/replacement_library/build_replacement_library.py create mode 100644 modelopt/torch/_compress/replacement_library/replacement_library.py create mode 100644 modelopt/torch/_compress/replacement_library/replacement_utils.py diff --git a/modelopt/torch/_compress/build_library_and_stats.py b/modelopt/torch/_compress/build_library_and_stats.py new file mode 100644 index 0000000000..19bd4f03cc --- /dev/null +++ b/modelopt/torch/_compress/build_library_and_stats.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified command that runs build_replacement_library followed by calc_subblock_stats. + +This script combines the functionality of both commands into a single workflow: +1. First, it builds the replacement library for the puzzle +2. Then, it calculates subblock statistics + +Usage: + + python modelopt.torch._compress.build_library_and_stats.py --config-dir configs --config-name Llama-3_1-8B puzzle_dir=/path/to/puzzle/dir dataset_path=/path/to/dataset + +The script uses the same Hydra configuration as the individual commands and supports +all the same configuration parameters for both build_replacement_library and calc_subblock_stats. +""" + +import hydra +from calc_subblock_stats import launch_calc_subblock_stats +from omegaconf import DictConfig + +from modelopt.torch._compress.replacement_library.build_replacement_library import ( + launch_build_replacement_library, +) +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.utils.parsing import format_global_config + + +def launch_build_library_and_stats(cfg: DictConfig) -> None: + """ + Launch both build_replacement_library and calc_subblock_stats in sequence. + + Args: + cfg: Hydra configuration containing settings for both commands + """ + mprint("=" * 80) + mprint("STARTING UNIFIED BUILD LIBRARY AND STATS WORKFLOW") + mprint("=" * 80) + + # Step 1: Build replacement library + mprint("=" * 50) + mprint("STEP 1: Building Replacement Library") + mprint("=" * 50) + + try: + launch_build_replacement_library(cfg) + mprint("✅ Replacement library built successfully!") + except Exception as e: + mprint(f"❌ Failed to build replacement library: {e}") + raise + + # Step 2: Calculate subblock statistics + mprint("=" * 50) + mprint("STEP 2: Calculating Subblock Statistics") + mprint("=" * 50) + + try: + launch_calc_subblock_stats(cfg) + mprint("✅ Subblock statistics calculated successfully!") + except Exception as e: + mprint(f"❌ Failed to calculate subblock statistics: {e}") + raise + + mprint("=" * 80) + mprint("UNIFIED WORKFLOW COMPLETED SUCCESSFULLY! 🎉") + mprint("=" * 80) + + mprint("Generated files:") + mprint(f" - {cfg.puzzle_dir}/block_library.json") + mprint(f" - {cfg.puzzle_dir}/subblock_library.json") + mprint(f" - {cfg.puzzle_dir}/replacement_library.json") + mprint(f" - {cfg.puzzle_dir}/single_sequence_replacement_solutions.json") + mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}") + if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"): + mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}") + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + """ + Main entry point for the unified build library and stats command. + + This function uses Hydra for configuration management and runs both + build_replacement_library and calc_subblock_stats in sequence. + """ + cfg = hydra.utils.instantiate(cfg) + mprint("Unified Build Library and Stats Configuration:") + mprint(format_global_config(cfg)) + launch_build_library_and_stats(cfg) + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 8fbf7c7c47..72c40f729f 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -23,13 +23,13 @@ import datetime from pathlib import Path -import build_library_and_stats import mip_and_realize_models import scoring import torch from torch import nn import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts +from modelopt.torch._compress import build_library_and_stats from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, @@ -123,6 +123,7 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR ) # Convert Llama3 model to DeciLM model + # TODO: Make it generic, do not call convert_llama3_to_decilm directly. if runtime.global_rank == 0: mprint("Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)") hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable diff --git a/modelopt/torch/_compress/replacement_library/build_replacement_library.py b/modelopt/torch/_compress/replacement_library/build_replacement_library.py new file mode 100644 index 0000000000..a8b2b7f9b6 --- /dev/null +++ b/modelopt/torch/_compress/replacement_library/build_replacement_library.py @@ -0,0 +1,605 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module constructs the replacement library JSON files from a puzzle directory containing +multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock +configurations, builds a library of available replacements, and generates solutions for layer +replacement in compressed models. The resulting replacement library can then be used by +ReplacementLibrary to efficiently load models with mixed teacher/student layers. + +Standard Puzzle Usage: +====================== +python -m modelopt.torch._compress.replacement_library.build_replacement_library PUZZLE_DIR + +Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended) +though you can supply an explicit --teacher_checkpoint_dir. + +--add_ffn_no_ops and --add_attention_no_ops are optional (default True), + + +Untrained puzzle run (with bypass): +=================================== +The subblock that doesn't interest you in the checkpoint should be no_op. + +""" +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any, Type + +import hydra +import pandas as pd +from omegaconf import DictConfig + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) +from modelopt.torch._compress.replacement_library.replacement_utils import ( + is_replacement_identical_to_teacher, + replacement_is_teacher, + sort_replacements, +) +from modelopt.torch._compress.tools.checkpoint_utils import ( + SAFETENSORS_SUBBLOCKS_DIR_NAME, + is_valid_decilm_checkpoint, + load_model_config, +) +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch._compress.utils.parsing import format_global_config +from modelopt.torch._compress.utils.utils import block_config_to_str, subblock_config_to_str + +UNIQUE_SUBBLOCK_IDENTIFIER = ["block_config", "attention_config", "ffn_config", "block_idx"] +CHECKPOINTS_DIR_NAME = "ckpts" + + +def build_replacement_library( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str | None = None, + add_ffn_no_ops: bool = True, + add_attention_no_ops: bool = True, +) -> None: + """ + For normal puzzle runs, use default values. + For advanced use cases, see the Usage section. + """ + master_puzzle_dir = Path(master_puzzle_dir) + (master_puzzle_dir / "ckpts").mkdir(exist_ok=True) + teacher_checkpoint_dir = infer_teacher_dir(master_puzzle_dir, teacher_checkpoint_dir) + subblocks_df = _build_subblocks_df( + master_puzzle_dir, + teacher_checkpoint_dir, + add_ffn_no_ops, + add_attention_no_ops, + ) + block_library_df = _build_block_library_from_subblocks(subblocks_df) + + layer_replacements = _build_layer_replacements( + block_library_df, master_puzzle_dir, teacher_checkpoint_dir + ) + + single_sequence_replacement_solutions = _build_single_sequence_replacement_solutions( + layer_replacements, teacher_checkpoint_dir + ) + + json_dump(block_library_df.to_dict(orient="records"), master_puzzle_dir / "block_library.json") + json_dump(subblocks_df.to_dict(orient="records"), master_puzzle_dir / "subblock_library.json") + json_dump(layer_replacements, master_puzzle_dir / "replacement_library.json") + json_dump( + single_sequence_replacement_solutions, + master_puzzle_dir / "single_sequence_replacement_solutions.json", + ) + mprint("done") + + +def launch_build_replacement_library(cfg: DictConfig) -> None: + """ + Launch the build replacement library function with Hydra configuration. + """ + mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}") + mprint(f"Teacher directory: {cfg.teacher_dir}") + mprint( + f"Build replacement library config: {format_global_config(cfg.build_replacement_library, title='Build replacement library')}" + ) + + build_replacement_library( + master_puzzle_dir=cfg.puzzle_dir, + teacher_checkpoint_dir=cfg.teacher_dir, + add_ffn_no_ops=cfg.build_replacement_library.add_ffn_no_ops, + add_attention_no_ops=cfg.build_replacement_library.add_attention_no_ops, + ) + + +def infer_teacher_dir( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str | None = None, +) -> Path: + if teacher_checkpoint_dir is None: + teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher" + if not teacher_checkpoint_dir.exists(): + raise ValueError( + f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the " + f"teacher dir under '{{PUZZLE_DIR}}/ckpts'." + ) + teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute() + return teacher_checkpoint_dir + + +def _build_block_library_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame: + joint_blocks_df = subblocks_df.dropna(subset=["block_config"]).copy() + constructed_blocks_df = _construct_blocks_from_subblocks(subblocks_df) + + is_constructed_block_has_joint_variant = pd.Series( + map(tuple, constructed_blocks_df[["block_config", "block_idx"]].values) + ).isin(pd.Series(map(tuple, joint_blocks_df[["block_config", "block_idx"]].values))) + constructed_blocks_df = constructed_blocks_df[~is_constructed_block_has_joint_variant] + + block_library_df = pd.concat([joint_blocks_df, constructed_blocks_df]) + block_library_df["block_repr"] = block_library_df["block_config"].apply(block_config_to_str) + + dups = block_library_df.loc[ + block_library_df[["block_config", "block_idx"]].duplicated() + ].sort_values(by=["block_config", "block_idx"]) + if len(dups) > 0: + mprint(f"Found {len(dups)} duplicate blocks in the block library. Here are some examples:") + dup_block_idx = dups["block_idx"].iloc[0] + dups_with_same_block_idx = dups[dups["block_idx"] == dup_block_idx] + for _, row in dups_with_same_block_idx.head(10).iterrows(): + mprint(row.to_dict()) + json_dump(block_library_df.to_dict(orient="records"), "ERROR_block_library.json") + json_dump(subblocks_df.to_dict(orient="records"), "ERROR_subblock_library.json") + raise ValueError( + f"Found {len(dups)} duplicate blocks in the block library. See ERROR_block_library.json and ERROR_subblock_library.json for more details." + ) + + return block_library_df + + +def _construct_blocks_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame: + columns = subblocks_df.columns + decomp_blocks_df = subblocks_df[subblocks_df["block_config"].isna()].drop( + columns=columns[columns.str.contains("block_config|joint|block_repr")] + ) + + attention_df = decomp_blocks_df.dropna(subset="attention_config").drop( + columns=columns[columns.str.contains("ffn")] + ) + ffn_df = decomp_blocks_df.dropna(subset="ffn_config").drop( + columns=columns[columns.str.contains("attention")] + ) + constructed_blocks_df = pd.merge(attention_df, ffn_df, on="block_idx") + + constructed_blocks_df["block_config"] = constructed_blocks_df.apply( + lambda row: BlockConfig(ffn=row["ffn_config"], attention=row["attention_config"]), axis=1 + ) + + return constructed_blocks_df + + +def _build_subblocks_df( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str, + add_ffn_no_ops: bool, + add_attention_no_ops: bool, +) -> pd.DataFrame: + teacher_checkpoint_dir = Path(teacher_checkpoint_dir) + checkpoint_dirs = _get_last_checkpoint_from_each_experiment(master_puzzle_dir) + checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir}) + checkpoints_to_split = [teacher_checkpoint_dir] + + subblock_rows = [] + for checkpoint_dir in checkpoint_dirs: + subblocks_to_extract = _infer_subblocks_to_extract(checkpoint_dir, checkpoints_to_split) + if len(subblocks_to_extract) > 0: + subblock_rows_from_current_checkpoint = ( + _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir, subblocks_to_extract + ) + ) + subblock_rows.extend(subblock_rows_from_current_checkpoint) + + subblocks_df = pd.DataFrame(subblock_rows) + + subblocks_df = _drop_duplicates_of_decomp_no_op(subblocks_df) + assert subblocks_df.duplicated().sum() == 0 + + if add_ffn_no_ops or add_attention_no_ops: + subblocks_df = _add_no_op_subblock_rows(subblocks_df, add_ffn_no_ops, add_attention_no_ops) + + subblocks_df = _drop_duplicates_of_teacher(subblocks_df, teacher_checkpoint_dir) + + subblocks_that_have_multiple_sources = list( + subblocks_df[subblocks_df.duplicated(UNIQUE_SUBBLOCK_IDENTIFIER, keep=False)].groupby( + UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False + ) + ) + if len(subblocks_that_have_multiple_sources) > 0: + mprint( + f"Found {len(subblocks_that_have_multiple_sources)} subblock types with multiple sources. Dropping duplicates..." + ) + for subblock_identifier, duplicates_df in subblocks_that_have_multiple_sources: + mprint("\n================================") + mprint(dict(zip(UNIQUE_SUBBLOCK_IDENTIFIER, subblock_identifier))) + for _, row in duplicates_df.iterrows(): + mprint(row.to_dict()) + + # Drop duplicates, keeping the first occurrence (which should be from teacher) + mprint(f"Dropping duplicates. Original count: {len(subblocks_df)}") + subblocks_df = subblocks_df.drop_duplicates(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + mprint(f"After dropping duplicates: {len(subblocks_df)}") + + subblocks_df["ffn_repr"] = subblocks_df["ffn_config"].apply(subblock_config_to_str) + subblocks_df["attention_repr"] = subblocks_df["attention_config"].apply(subblock_config_to_str) + subblocks_df["block_repr"] = subblocks_df["block_config"].apply(block_config_to_str) + + return subblocks_df + + +def _drop_duplicates_of_teacher( + subblocks_df: pd.DataFrame, + teacher_checkpoint_dir: Path | str, +) -> pd.DataFrame: + orig_subblocks_df = subblocks_df.copy() + + attention_is_teacher = subblocks_df["attention_checkpoint_dir"] == str(teacher_checkpoint_dir) + ffn_is_teacher = subblocks_df["ffn_checkpoint_dir"] == str(teacher_checkpoint_dir) + is_joint_teacher = attention_is_teacher & ffn_is_teacher + + is_decomp_attention = subblocks_df["ffn_config"].isna() + is_decomp_ffn = subblocks_df["attention_config"].isna() + is_joint_block = ~is_decomp_attention & ~is_decomp_ffn + + student_indices_that_have_teacher_dups = [] + + for current_subset, is_teacher in [ + (is_decomp_attention, attention_is_teacher), + (is_decomp_ffn, ffn_is_teacher), + (is_joint_block, is_joint_teacher), + ]: + subblocks_df = orig_subblocks_df.copy().loc[current_subset] + + subblocks_df["is_student"] = ~is_teacher.loc[current_subset] + + def get_student_indices_that_have_teacher_dups(grouped_is_student: pd.Series) -> list: + if grouped_is_student.all(): + return [] + return grouped_is_student.index[grouped_is_student].tolist() + + current_student_indices_that_have_teacher_dups = [ + dup_index + for dup_list in subblocks_df.groupby(UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False)[ + "is_student" + ].apply(get_student_indices_that_have_teacher_dups) + for dup_index in dup_list + ] + student_indices_that_have_teacher_dups.extend( + current_student_indices_that_have_teacher_dups + ) + + dedup_subblocks_df = orig_subblocks_df.drop(index=student_indices_that_have_teacher_dups) + return dedup_subblocks_df + + +def _drop_duplicates_of_decomp_no_op(subblocks_df: pd.DataFrame) -> pd.DataFrame: + is_decomp = subblocks_df["block_config"].isna() + is_ffn_no_op = subblocks_df["ffn_config"].apply(lambda conf: conf is not None and conf.no_op) + is_attention_no_op = subblocks_df["attention_config"].apply( + lambda conf: conf is not None and conf.no_op + ) + is_duplicated = subblocks_df.duplicated(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + is_dup_of_decomp_no_op = is_duplicated & is_decomp & (is_ffn_no_op | is_attention_no_op) + subblocks_df = subblocks_df[~is_dup_of_decomp_no_op] + return subblocks_df + + +def _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir: Path, subblocks_to_extract: list[str] +) -> list[dict[str, Any]]: + subblock_rows_from_current_checkpoint = [] + model_config = load_model_config(checkpoint_dir) + for block_idx, block_config in enumerate(model_config.block_configs): + for subblock_to_extract in subblocks_to_extract: + subblock_row = _init_empty_subblock_row(block_idx) + + if subblock_to_extract == "block": + subblock_row["block_config"] = block_config + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "ffn": + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "attention": + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + else: + raise ValueError() + + subblock_rows_from_current_checkpoint.append(subblock_row) + return subblock_rows_from_current_checkpoint + + +def _add_no_op_subblock_rows( + subblocks_df: pd.DataFrame, + add_ffn_no_op: bool, + add_attention_no_op: bool, +) -> pd.DataFrame: + n_layer = subblocks_df["block_idx"].max() + 1 + + no_op_subblocks = [] + if add_ffn_no_op: + no_op_subblocks.append("ffn") + if add_attention_no_op: + no_op_subblocks.append("attention") + + additional_no_op_rows = [] + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, subblock_cls = _get_rows_with_no_op_subblock( + subblocks_df, no_op_subblock + ) + existing_no_op_indices = rows_with_no_op_subblock["block_idx"].values + missing_no_op_indices = list(set(range(n_layer)) - set(existing_no_op_indices)) + for block_idx in missing_no_op_indices: + no_op_subblock_row = { + **_init_empty_subblock_row(block_idx), + f"{no_op_subblock}_config": subblock_cls(no_op=True), + } + additional_no_op_rows.append(no_op_subblock_row) + + subblocks_df = pd.concat([subblocks_df, pd.DataFrame(additional_no_op_rows)]) + + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, _ = _get_rows_with_no_op_subblock(subblocks_df, no_op_subblock) + assert len(rows_with_no_op_subblock) == n_layer, ( + f"Got {len(rows_with_no_op_subblock)} rows with {no_op_subblock}=no_op, but we have {n_layer} layers" + ) + return subblocks_df + + +def _get_rows_with_no_op_subblock( + subblocks_df: pd.DataFrame, no_op_subblock: str +) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]: + other_subblock = "ffn" if no_op_subblock == "attention" else "attention" + subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig + no_op_subblock_config = subblock_cls(no_op=True) + rows_with_no_op_subblock = subblocks_df[ + (subblocks_df[f"{no_op_subblock}_config"] == no_op_subblock_config) + & subblocks_df[f"{other_subblock}_config"].isna() + ] + return rows_with_no_op_subblock, subblock_cls + + +def _get_last_checkpoint_from_each_experiment(master_puzzle_dir: Path | str) -> set[Path]: + master_puzzle_dir = Path(master_puzzle_dir) + master_checkpoints_dir = master_puzzle_dir / CHECKPOINTS_DIR_NAME + subdirs_of_master_checkpoints_dir = [ + p.resolve() for p in master_checkpoints_dir.iterdir() if p.is_dir() + ] + checkpoint_dirs = [ + p.parent + for subdir in subdirs_of_master_checkpoints_dir + for p in subdir.rglob("config.json") + ] + + for checkpoint_dir in checkpoint_dirs: + if checkpoint_dir == master_checkpoints_dir: + raise ValueError( + f"We need at least 1 hierarchy level under the '{CHECKPOINTS_DIR_NAME}' dir. " + "Name your checkpoints, preferably with meaningful names. " + "If you are Ido Galil, tell Tomer that you got this exception ;) " + ) + + # Filter out non-DeciLM checkpoints (e.g., unconverted Llama checkpoints) + valid_checkpoint_dirs = [cp for cp in checkpoint_dirs if is_valid_decilm_checkpoint(cp)] + + experiment_dirs = [ + p if (p in subdirs_of_master_checkpoints_dir) else p.parent for p in valid_checkpoint_dirs + ] + + deduped_checkpoint_dirs = set( + pd.DataFrame({"checkpoint_dir": valid_checkpoint_dirs, "experiment_dir": experiment_dirs}) + .sort_values("checkpoint_dir") + .drop_duplicates(subset="experiment_dir", keep="last")["checkpoint_dir"] + .tolist() + ) + return deduped_checkpoint_dirs + + +def _infer_subblocks_to_extract( + checkpoint_dir: Path, + checkpoints_to_split: list[Path], +) -> list[str]: + if (checkpoint_dir / "replacement_library.json").exists(): + return [] + bypass_config_path = checkpoint_dir / "bypass_config.json" + if (checkpoint_dir in checkpoints_to_split) or (not bypass_config_path.exists()): + subblocks_to_extract = ["block", "attention", "ffn"] + else: + bypass_config = json.loads(bypass_config_path.read_text()) + keys_to_learn = bypass_config.get("keys_to_learn", "entire_block") + if keys_to_learn == "entire_block": + subblocks_to_extract = ["block"] + elif "mlp" in keys_to_learn and "attn" not in keys_to_learn: + subblocks_to_extract = ["ffn"] + elif "attn" in keys_to_learn and "mlp" not in keys_to_learn: + subblocks_to_extract = ["attention"] + else: + raise ValueError(f"Unrecognized {keys_to_learn=}") + return subblocks_to_extract + + +def _init_empty_subblock_row(block_idx: int) -> dict[str, Any]: + return { + "attention_checkpoint_dir": None, + "ffn_checkpoint_dir": None, + "block_config": None, + "attention_config": None, + "ffn_config": None, + "block_idx": block_idx, + "block_repr": None, + "attention_repr": None, + "ffn_repr": None, + } + + +def _build_layer_replacements( + block_library_df: pd.DataFrame, + master_puzzle_dir: Path, + teacher_checkpoint_dir: Path, +) -> list[dict]: + layer_replacements_from_blocks = _build_layer_replacements_from_block_library(block_library_df) + layer_replacements_from_checkpoints = _gather_layer_replacements_from_checkpoints( + master_puzzle_dir + ) + layer_replacements = layer_replacements_from_blocks + layer_replacements_from_checkpoints + layer_replacements = _filter_duplicate_teacher_replacements( + layer_replacements, teacher_checkpoint_dir + ) + return layer_replacements + + +def _build_layer_replacements_from_block_library(block_library_df: pd.DataFrame) -> list[dict]: + layer_replacements = [] + for _, row in block_library_df.iterrows(): + block_idx = row["block_idx"] + block_config = row["block_config"] + weight_paths = [] + for subblock_name in ["attention", "ffn"]: + checkpoint_dir = row[f"{subblock_name}_checkpoint_dir"] + if checkpoint_dir is not None: + subblock_path = ( + Path(checkpoint_dir) + / SAFETENSORS_SUBBLOCKS_DIR_NAME + / f"block_{block_idx}_{subblock_name}.safetensors" + ) + weight_paths.append(subblock_path) + weight_paths = sorted(set(weight_paths)) + layer_replacement = { + "parent_layer_indices": [block_idx], + "child_block_configs": [block_config], + "weight_paths": weight_paths, + } + layer_replacements.append(layer_replacement) + return layer_replacements + + +def _gather_layer_replacements_from_checkpoints(master_puzzle_dir: str | Path) -> list[dict]: + gathered_layer_replacements = [] + checkpoint_dirs = _get_last_checkpoint_from_each_experiment(master_puzzle_dir) + for checkpoint_dir in checkpoint_dirs: + if (layer_replacements_path := checkpoint_dir / "replacement_library.json").exists(): + layer_replacements = json.loads(layer_replacements_path.read_text()) + for layer_replacement in layer_replacements: + layer_replacement["child_block_configs"] = [ + BlockConfig(**block_config_dict) + for block_config_dict in layer_replacement["child_block_configs"] + ] + layer_replacement["weight_paths"] = sorted( + set(Path(p) for p in layer_replacement["weight_paths"]) + ) + gathered_layer_replacements.extend(layer_replacements) + return gathered_layer_replacements + + +def _filter_duplicate_teacher_replacements( + layer_replacements: list[dict], + teacher_checkpoint_dir: Path, +) -> list[dict]: + teacher_model_config = load_model_config(teacher_checkpoint_dir) + filtered_layer_replacements = [] + for layer_replacement in layer_replacements: + if replacement_is_teacher( + layer_replacement, teacher_model_config, teacher_checkpoint_dir + ) or not is_replacement_identical_to_teacher(layer_replacement, teacher_model_config): + filtered_layer_replacements.append(layer_replacement) + return filtered_layer_replacements + + +def _build_single_sequence_replacement_solutions( + layer_replacements: list[dict], + teacher_checkpoint_dir: Path, +) -> list[dict]: + teacher_model_config = load_model_config(teacher_checkpoint_dir) + n_layer = teacher_model_config.num_hidden_layers + + teacher_replacements = dict() + student_replacements = [] + for layer_replacement in layer_replacements: + if replacement_is_teacher(layer_replacement, teacher_model_config, teacher_checkpoint_dir): + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_replacements[block_idx] = layer_replacement + else: + student_replacements.append(layer_replacement) + + teacher_indices_represented_in_replacements = sorted(teacher_replacements.keys()) + assert teacher_indices_represented_in_replacements == list(range(n_layer)), ( + f"{n_layer=}, {teacher_indices_represented_in_replacements=}" + ) + + student_replacements = sort_replacements(student_replacements) + + solutions = [] + for layer_replacement in student_replacements: + block_indices_not_represented_in_replacement = sorted( + set(range(n_layer)) - set(layer_replacement["parent_layer_indices"]) + ) + chosen_replacements = sort_replacements( + [layer_replacement] + + [ + teacher_replacements[block_idx] + for block_idx in block_indices_not_represented_in_replacement + ] + ) + + block_configs = [ + block_config + for replacement in chosen_replacements + for block_config in replacement["child_block_configs"] + ] + + solutions.append( + { + "single_sequence_replacement": layer_replacement, + "chosen_replacements": chosen_replacements, + "block_configs": block_configs, + } + ) + + return solutions + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + mprint(format_global_config(cfg)) + launch_build_replacement_library(cfg) + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/replacement_library/replacement_library.py b/modelopt/torch/_compress/replacement_library/replacement_library.py new file mode 100644 index 0000000000..ccfaaee0de --- /dev/null +++ b/modelopt/torch/_compress/replacement_library/replacement_library.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Replacement library for efficiently loading and managing layer-replaced DeciLM models. +- Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations +""" +# mypy: ignore-errors + +import json +import re +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +from immutabledict import immutabledict +from lru import LRU +from safetensors.torch import load_file as safe_load_file +from torch import nn + +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMDecoderLayer, + DeciLMForCausalLM, + DeciLMMultiDecoderLayer, + DeciLMRMSNorm, + LMHead, +) +from modelopt.torch._compress.replacement_library.replacement_utils import ( + extract_block_configs_and_locations, + parse_layer_replacement, + sort_replacements, + weights_path_to_checkpoint_dir, +) +from modelopt.torch._compress.tools.checkpoint_utils import ( + PTH_SUBBLOCKS_DIR_NAME, + SAFETENSORS_SUBBLOCKS_DIR_NAME, + infer_weights_dtype, + init_empty_module, + init_module_with_state_dict, + load_model_config, +) +from modelopt.torch._compress.tools.sharded_checkpoint_utils import ( + create_dummy_model, + is_in_safetensors_format, + load_sharded_state_dict, +) + + +class ReplacementLibrary: + def __init__( + self, + replacement_library_path: str | Path, + model_config_overrides: Optional[dict] = None, + ): + self.replacement_library = self._load_replacement_library(replacement_library_path) + self._ensure_all_checkpoints_are_split() + self.model_config_overrides = ( + immutabledict(model_config_overrides) if (model_config_overrides is not None) else None + ) + + self._loaded_replacements: dict[str, nn.ModuleList] = LRU( + size=256 + ) # least-recently-used dict: a dict of fixed size that evicts old items + + self._dtype = None + + self.teacher_dir = Path(replacement_library_path).parent / "ckpts" / "teacher" + self._model_config = None + self._embedding = None + self._ln_f = None + self._lm_head = None + self._arbitrary_checkpoint_dir = None + + @staticmethod + def _load_replacement_library(replacement_library_path: str | Path) -> list[dict]: + replacement_library = json.loads(Path(replacement_library_path).read_text()) + replacement_library = [ + parse_layer_replacement(layer_replacement) for layer_replacement in replacement_library + ] + return replacement_library + + def _ensure_all_checkpoints_are_split(self) -> None: + checkpoint_dirs = self._get_all_checkpoint_dirs() + unsplit_checkpoints = [] + for checkpoint_dir in checkpoint_dirs: + if not (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists(): + unsplit_checkpoints.append(checkpoint_dir) + assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}" + + @property + def dtype(self) -> torch.dtype: + if self._dtype is None: + ln_f = self.get_ln_f() + self._dtype = ln_f.weight.dtype + return self._dtype + + @property + def n_layer(self) -> int: + return self.model_config.get_num_hidden_layers() + + @property + def model_config(self) -> DeciLMConfig: + if self._model_config is None: + self._model_config = load_model_config( + self.get_arbitrary_checkpoint_dir(), self.model_config_overrides + ) + return self._model_config + + def create_model_config(self, layer_replacements: list[dict]): + block_configs, _ = extract_block_configs_and_locations(layer_replacements) + model_config = self.model_config.set_block_configs(block_configs) + return model_config + + def load_model( + self, + layer_replacements: list[dict], + world_size: int, + global_rank: int, + ) -> DeciLMForCausalLM: + block_configs, block_locations = extract_block_configs_and_locations(layer_replacements) + model_config = self.model_config.set_block_configs(block_configs) + + owned_block_indexes = _get_owned_block_indexes( + model_config.get_num_hidden_layers(), world_size, global_rank + ) + model = create_dummy_model(model_config, self.dtype) + + is_first_shard = 0 in owned_block_indexes + if is_first_shard and not isinstance(model.model.get_input_embeddings(), nn.Embedding): + model.set_input_embeddings(self.get_embedding()) + + is_last_shard = model_config.get_num_hidden_layers() - 1 in owned_block_indexes + if is_last_shard and not isinstance(model.model.get_output_embeddings(), nn.Linear): + model.model.set_final_layer_norm(self.get_ln_f()) + model.set_output_embeddings(self.get_lm_head()) + + active_blocks = [] + for block_idx in owned_block_indexes: + layer_replacement, block_idx_in_replacement = block_locations[block_idx] + block = self.get_block(layer_replacement, block_idx_in_replacement) + model.model.layers[block_idx] = block + active_blocks.append(block) + + self._move_inactive_blocks_to_cpu(active_blocks) + return model + + def load_checkpoint( + self, + checkpoint_dir: str | Path, + world_size: int, + global_rank: int, + ) -> DeciLMForCausalLM: + checkpoint_dir = Path(checkpoint_dir).resolve() + layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) + model = self.load_model(layer_replacements, world_size, global_rank) + return model + + def _locate_replacements_of_entire_checkpoint(self, checkpoint_dir: str | Path) -> list[dict]: + weight_paths_located = [] + layer_replacements = [] + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + weight_paths = [Path(p).absolute().resolve() for p in weight_paths] + layer_replacement["weight_paths"] = weight_paths + if len(weight_paths) > 0 and all( + p.is_relative_to(checkpoint_dir) for p in weight_paths + ): + layer_replacements.append(layer_replacement) + weight_paths_located.extend(weight_paths) + + all_block_weight_paths = [ + p + for p in list((checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).iterdir()) + if p.name not in ("embeddings.safetensors", "lm_head.safetensors") + ] + missing_paths = set(all_block_weight_paths) - set(weight_paths_located) + assert len(missing_paths) == 0, ( + f"Couldn't locate replacements for the entire checkpoint {checkpoint_dir}, missing weights: {missing_paths}" + ) + + dedupped_layer_replacements = [] + for weights_path in all_block_weight_paths: + replacements_with_path = [ + rep for rep in layer_replacements if weights_path in rep["weight_paths"] + ] + largets_replacement_with_path = max( + replacements_with_path, key=lambda rep: len(rep["weight_paths"]) + ) + if largets_replacement_with_path not in dedupped_layer_replacements: + dedupped_layer_replacements.append(largets_replacement_with_path) + + dedupped_layer_replacements = sort_replacements(dedupped_layer_replacements) + return dedupped_layer_replacements + + def get_block( + self, layer_replacement: dict, block_idx_in_replacement: int + ) -> DeciLMDecoderLayer | DeciLMMultiDecoderLayer: + if str(layer_replacement) not in self._loaded_replacements.keys(): + self._loaded_replacements[str(layer_replacement)] = self._load_layer_replacement( + layer_replacement + ) + module_list = self._loaded_replacements[str(layer_replacement)] + block = module_list[block_idx_in_replacement] + return block + + def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: + state_dict = dict() + for weights_path in layer_replacement["weight_paths"]: + if weights_path.suffix == ".safetensors": + curr_state_dict = safe_load_file(weights_path) + elif weights_path.suffix == ".pth": + curr_state_dict = torch.load(weights_path, weights_only=True) + else: + raise ValueError(f"Unrecognized suffix of {weights_path=}") + for param_name in curr_state_dict.keys(): + assert param_name not in state_dict, ( + f"Duplicate entries for {param_name=} in {layer_replacement=}" + ) + state_dict.update(curr_state_dict) + + if len(state_dict) > 0: + block_indices = [ + int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) + for param_name in state_dict.keys() + ] + assert sorted(set(block_indices)) == list( + range(min(block_indices), max(block_indices) + 1) + ), ( + f"Block indices in loaded weight files must be consecutive, but found {sorted(set(block_indices))} in {layer_replacement=}" + ) + + min_block_idx = min(block_indices) + + state_dict = { + param_name.replace( + f"model.layers.{block_idx}.", f"{block_idx - min_block_idx}." + ): param_weight + for block_idx, (param_name, param_weight) in zip(block_indices, state_dict.items()) + } + + dtype = infer_weights_dtype(state_dict) + model_config = self.model_config.set_block_configs(layer_replacement["child_block_configs"]) + + module_list = nn.ModuleList( + [ + ( + init_empty_module(DeciLMDecoderLayer, dtype, model_config, layer_idx) + if (block_config.parallel_blocks is None) + else init_empty_module(DeciLMMultiDecoderLayer, dtype, model_config, layer_idx) + ) + for layer_idx, block_config in enumerate(layer_replacement["child_block_configs"]) + ] + ) + + module_list.load_state_dict(state_dict, strict=True) + return module_list + + def _move_inactive_blocks_to_cpu(self, active_blocks: list[nn.Module]) -> None: + for module_list in self._loaded_replacements.values(): + for module in module_list: + if module not in active_blocks: + module.to("cpu") + + def get_embedding(self) -> nn.Embedding: + if self._embedding is None: + state_dict = { + "weight": self._get_arbitrary_non_block_param( + self.model_config.get_embedding_layer_name() + ".weight" + ) + } + self._embedding = init_module_with_state_dict( + state_dict, + nn.Embedding, + num_embeddings=self.model_config.vocab_size, + embedding_dim=self.model_config.hidden_size, + ) + return self._embedding + + def get_ln_f(self) -> DeciLMRMSNorm: + if self._ln_f is None: + state_dict = { + "weight": self._get_arbitrary_non_block_param( + self.model_config.get_final_layer_norm_layer_name() + ".weight" + ) + } + self._ln_f = init_module_with_state_dict( + state_dict, + DeciLMRMSNorm, + hidden_size=self.model_config.hidden_size, + eps=self.model_config.rms_norm_eps, + ) + return self._ln_f + + def get_lm_head(self) -> nn.Linear: + if self._lm_head is None: + state_dict = { + "weight": self._get_arbitrary_non_block_param( + self.model_config.get_lm_head_layer_name() + ".weight" + ) + } + self._lm_head = init_module_with_state_dict( + state_dict, + LMHead, + out_features=self.model_config.vocab_size, + in_features=self.model_config.hidden_size, + bias=False, + ) + return self._lm_head + + def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: + checkpoint_dir = self.get_arbitrary_checkpoint_dir() + if ( + is_in_safetensors_format(checkpoint_dir) + or (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists() + ): + partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) + return partial_state_dict[param_name] + + non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" + assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) + non_block_state_dict = torch.load(non_block_pth_path) + return non_block_state_dict[param_name] + + def get_arbitrary_checkpoint_dir(self) -> Path: + if self._arbitrary_checkpoint_dir is None: + self._arbitrary_checkpoint_dir = self._get_arbitrary_checkpoint_dir() + return self._arbitrary_checkpoint_dir + + def get_teacher_dir(self) -> Path: + return self.teacher_dir + + def get_teacher_lm_head_path(self) -> Path: + return self.get_teacher_dir() / SAFETENSORS_SUBBLOCKS_DIR_NAME / "lm_head.safetensors" + + def get_teacher_embedding_path(self) -> Path: + return self.get_teacher_dir() / SAFETENSORS_SUBBLOCKS_DIR_NAME / "embeddings.safetensors" + + def _get_arbitrary_checkpoint_dir(self) -> Path: + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + if len(weight_paths) > 0: + return weights_path_to_checkpoint_dir(weight_paths[0]) + + def _get_all_checkpoint_dirs(self) -> list[Path]: + checkpoint_dirs = set() + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + for weights_path in weight_paths: + checkpoint_dir = weights_path_to_checkpoint_dir(weights_path) + checkpoint_dirs.add(checkpoint_dir) + return list(checkpoint_dirs) + + +def _error_message_ensure_split(checkpoint_dir: Path) -> str: + return ( + f"Encountered unsplit checkpoint dir '{checkpoint_dir}', " + f"please call `ensure_all_checkpoints_are_split`" + ) + + +def _get_owned_block_indexes(n_layer: int, world_size: int, global_rank: int) -> list[int]: + last_process_blocks = np.array([n_layer - 1]) # less params in last gpu, leave room for logits + + if world_size == 1: + # Only one process: assign everything (including the "last process" block) to rank 0 + owned_block_indexes_per_process = [ + np.concatenate([np.arange(n_layer - 1), last_process_blocks]) + ] + else: + # Multiple processes: split n_layer-1 blocks, reserve the last for "last process" + owned_block_indexes_per_process = np.array_split(range(n_layer - 1), world_size - 1) + owned_block_indexes_per_process.append(last_process_blocks) + + owned_block_indexes = owned_block_indexes_per_process[global_rank].tolist() + return owned_block_indexes diff --git a/modelopt/torch/_compress/replacement_library/replacement_utils.py b/modelopt/torch/_compress/replacement_library/replacement_utils.py new file mode 100644 index 0000000000..21ae411752 --- /dev/null +++ b/modelopt/torch/_compress/replacement_library/replacement_utils.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module provides helper functions for parsing, sorting, and analyzing layer replacement +configurations used in the replacement library for model compression. +""" + +# mypy: ignore-errors +import json +from copy import deepcopy +from pathlib import Path + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig + + +def parse_layer_replacement(layer_replacement: dict | str) -> dict: + if isinstance(layer_replacement, str): + layer_replacement = json.loads(layer_replacement) + else: + layer_replacement = deepcopy(layer_replacement) + + if "layer_replacement" in layer_replacement: # happens in puzzle solutions + layer_replacement = layer_replacement["layer_replacement"] + + layer_replacement["child_block_configs"] = [ + BlockConfig(**block_config) if isinstance(block_config, dict) else block_config + for block_config in layer_replacement["child_block_configs"] + ] + layer_replacement["weight_paths"] = [Path(p) for p in layer_replacement["weight_paths"]] + return layer_replacement + + +def sort_replacements(layer_replacements: list[dict]) -> list[dict]: + return sorted(layer_replacements, key=lambda replacement: replacement["parent_layer_indices"]) + + +def extract_block_configs_and_locations( + layer_replacements: list[dict], +) -> tuple[list[BlockConfig], list[tuple[dict, int]]]: + layer_replacements = sort_replacements(layer_replacements) + block_configs = [] + block_locations = [] + for layer_replacement in layer_replacements: + child_block_configs = layer_replacement["child_block_configs"] + if not isinstance(child_block_configs, list | tuple): + child_block_configs = [child_block_configs] + for block_idx_in_replacement, block_config in enumerate(child_block_configs): + block_configs.append(block_config) + block_locations.append((layer_replacement, block_idx_in_replacement)) + return block_configs, block_locations + + +def weights_path_to_checkpoint_dir(weights_path: Path) -> Path: + checkpoint_dir: Path = weights_path + while checkpoint_dir != Path("/"): + if (checkpoint_dir / "config.json").exists(): + return checkpoint_dir + checkpoint_dir = checkpoint_dir.parent + raise FileNotFoundError(f"Couldn't find checkpoint dir for weights path {weights_path}") + + +def replacement_is_teacher( + layer_replacement: dict, + teacher_model_config: DeciLMConfig, + teacher_checkpoint_dir: Path, +) -> bool: + paths_all_teacher = all( + p.is_relative_to(teacher_checkpoint_dir) for p in layer_replacement["weight_paths"] + ) + return paths_all_teacher and is_replacement_identical_to_teacher( + layer_replacement, teacher_model_config + ) + + +def is_replacement_identical_to_teacher( + layer_replacement: dict, + teacher_model_config: DeciLMConfig, +) -> bool: + if len(layer_replacement["parent_layer_indices"]) == 1: + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_block_config = teacher_model_config.block_configs[block_idx] + if len(child_block_configs := layer_replacement["child_block_configs"]) == 1: + replacement_block_config: BlockConfig = child_block_configs[0] + if replacement_block_config == teacher_block_config: + return True + else: + parallel_blocks = getattr(replacement_block_config, "parallel_blocks", None) + if ( + parallel_blocks is not None + and len(parallel_blocks) == 1 + and parallel_blocks[0].attention == teacher_block_config.attention + and parallel_blocks[0].ffn == teacher_block_config.ffn + ): + return True + return False + + +def split_replacements_to_teacher_and_student( + replacements: list[dict], + teacher_model_config: DeciLMConfig, + teacher_checkpoint_dir: Path, +) -> tuple[list[dict], list[dict]]: + teacher_replacements, student_replacements = [], [] + for replacement in replacements: + if replacement_is_teacher(replacement, teacher_model_config, teacher_checkpoint_dir): + teacher_replacements.append(replacement) + else: + student_replacements.append(replacement) + return teacher_replacements, student_replacements diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index 6e2ba9339a..74329bcd0a 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -13,8 +13,106 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses +from typing import Any + import torch +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) + + +def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: + """ + Convert a BlockConfig to a human-readable string representation. + + TODO: Consider a better place for this function. + Args: + block_config: BlockConfig dataclass or dict containing attention and ffn configs. + + Returns: + Formatted string with attention and FFN information, or None if input is None. + """ + if block_config is None: + return None + rep = "" + if dataclasses.is_dataclass(block_config): + block_config = dataclasses.asdict(block_config) + for subblock_name in ["attention", "ffn"]: + subblock_config = block_config[subblock_name] + rep += subblock_config_to_str(subblock_config, subblock_name) + return rep + + +def subblock_config_to_str( + subblock_config: FFNConfig | AttentionConfig | dict[str, Any] | None, + subblock_name: None | str = None, +) -> str | None: + """Convert a subblock config (FFN, Attention, Mamba, or MoE) to string. + + TODO: Consider a better place for this function. + Args: + subblock_config: FFNConfig, AttentionConfig dataclass or dict. + subblock_name: Name of subblock ('ffn', 'attention', 'mamba', 'moe'). + Auto-detected if subblock_config is a dataclass. + + Returns: + Formatted string showing subblock type and key parameters (e.g., intermediate_size, + n_heads_in_group), or None if input is None. + """ + if subblock_config is None: + return None + subblock_name = ( + "ffn" + if isinstance(subblock_config, FFNConfig) + else "mamba" + if isinstance(subblock_config, AttentionConfig) and subblock_config.is_mamba + else "attention" + if isinstance(subblock_config, AttentionConfig) + else subblock_name + ) + assert subblock_name is not None, "Must provide subblock_name if subblock_config is a dict." + + if dataclasses.is_dataclass(subblock_config): + subblock_config = dataclasses.asdict(subblock_config) + + if subblock_name == "attention" and subblock_config.get("mamba") is not None: + subblock_name = "mamba" + + if subblock_name == "ffn" and subblock_config.get("moe") is not None: + subblock_name = "moe" + + rep = f" {subblock_name}" + if subblock_config.get("no_op"): + rep += " no_op".ljust(8) + elif subblock_config.get("replace_with_linear"): + rep += " linear".ljust(8) + elif subblock_name == "ffn": + intermediate_size = subblock_config["intermediate_size"] + rep += f" intermediate_{intermediate_size}".ljust(8) + elif subblock_name == "attention": + n_heads_in_group = subblock_config["n_heads_in_group"] + rep += f" gqa_{n_heads_in_group}".ljust(8) + elif subblock_name == "mamba": + mamba_num_heads = subblock_config["mamba"]["num_heads"] + mamba_head_dim = subblock_config["mamba"]["head_dim"] + rep += f" num_heads_{mamba_num_heads} head_dim_{mamba_head_dim}".ljust(8) + elif subblock_name == "moe": + moe_num_local_experts = subblock_config["moe"]["num_local_experts"] + moe_expert_intermediate_dim = subblock_config["moe"]["expert_intermediate_dim"] + shared_expert_intermediate_dim = subblock_config["moe"]["shared_expert_intermediate_dim"] + num_experts_per_tok = subblock_config["moe"]["num_experts_per_tok"] + rep += f" num_experts_{moe_num_local_experts} expert_intermediate_dim_{moe_expert_intermediate_dim} shared_expert_intermediate_dim_{shared_expert_intermediate_dim} num_experts_per_tok_{num_experts_per_tok}".ljust( + 8 + ) + else: + raise ValueError(f"subblock_config_to_str: unrecognized subblock_name: {subblock_name}.") + + return rep + class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): def __init__(self, device=None, dtype=None): diff --git a/setup.py b/setup.py index ab70cdf68a..3eb41967d1 100644 --- a/setup.py +++ b/setup.py @@ -108,6 +108,8 @@ "wandb~=0.17.5", "lru-dict", "typeguard", + "pandas", + "immutabledict", ], } From 954103ed3342d0ede191d06c0d0108062b007b1a Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 1 Dec 2025 21:58:58 +0100 Subject: [PATCH 19/62] Add subblock stats to the compress algorithm (#623) ## What does this PR do? Add subblock stats to the compress algorithm. --------- Signed-off-by: Daniel Korzekwa --- .../_compress/build_library_and_stats.py | 2 +- .../calc_subblock_params_and_memory.py | 341 +++++++++++ .../subblock_stats/calc_subblock_stats.py | 554 ++++++++++++++++++ modelopt/torch/_compress/utils/utils.py | 50 ++ 4 files changed, 946 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py create mode 100644 modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py diff --git a/modelopt/torch/_compress/build_library_and_stats.py b/modelopt/torch/_compress/build_library_and_stats.py index 19bd4f03cc..f0735c98ff 100644 --- a/modelopt/torch/_compress/build_library_and_stats.py +++ b/modelopt/torch/_compress/build_library_and_stats.py @@ -30,12 +30,12 @@ """ import hydra -from calc_subblock_stats import launch_calc_subblock_stats from omegaconf import DictConfig from modelopt.torch._compress.replacement_library.build_replacement_library import ( launch_build_replacement_library, ) +from modelopt.torch._compress.subblock_stats.calc_subblock_stats import launch_calc_subblock_stats from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.utils.parsing import format_global_config diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py new file mode 100644 index 0000000000..7f5a417786 --- /dev/null +++ b/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Calculate memory usage and parameter counts for neural network subblocks. + +This module provides utilities to compute memory footprints and parameter counts +for different subblock types (FFN, Attention, Mamba, MoE) in large language models, +considering various data types, batch sizes, and sequence lengths. +""" + +import json +import math +from pathlib import Path + +import numpy as np +import torch + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + FFNConfig, + MambaConfig, +) +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMMoe +from modelopt.torch._compress.utils.utils import ( + calculate_kv_dim, + raise_unknown_subblock_config_error, + sizeof_dtype, +) + + +def calculate_subblock_memory( + subblock_config: FFNConfig | AttentionConfig, + batch_size: int, + prefill_seq_len: int, + generation_seq_len: int, + prefill_queue_size: int, + n_embd: int, + n_head: int, + weights_dtype: torch.dtype | str, + kv_cache_dtype: torch.dtype, + allocate_prefill_query: bool, +) -> float | dict[str, float]: + if subblock_config.no_op: + return 0 + if subblock_config.replace_with_linear: + return calculate_linear_memory(n_embd, weights_dtype) + if isinstance(subblock_config, FFNConfig): + return calculate_ffn_memory(subblock_config, n_embd, weights_dtype) + if isinstance(subblock_config, AttentionConfig): + if subblock_config.is_mamba: + return calculate_mamba_memory( + subblock_config.mamba, n_embd, batch_size, weights_dtype, kv_cache_dtype + ) + else: + return calculate_attention_memory( + subblock_config, + batch_size, + prefill_seq_len, + generation_seq_len, + prefill_queue_size, + n_embd, + n_head, + weights_dtype, + kv_cache_dtype, + allocate_prefill_query, + ) + raise_unknown_subblock_config_error(subblock_config) + + +def calculate_subblock_params( + subblock_config: FFNConfig | AttentionConfig, + n_embd: int, + n_head: int, +) -> int: + if subblock_config.no_op: + return 0 + if subblock_config.replace_with_linear: + return calculate_linear_params(n_embd) + if isinstance(subblock_config, FFNConfig): + return calculate_ffn_params(subblock_config, n_embd) + if isinstance(subblock_config, AttentionConfig): + if subblock_config.is_mamba: + return calculate_mamba_params(subblock_config.mamba, n_embd) + else: + return calculate_attention_params(subblock_config, n_embd, n_head) + raise_unknown_subblock_config_error(subblock_config) + + +def calc_subblock_active_params( + subblock_config: FFNConfig | AttentionConfig, + n_embd: int, + n_head: int, + moe_stats_file: str, + batch_size: int, + block_idx: int, +) -> int: + if not (isinstance(subblock_config, FFNConfig) and subblock_config.is_moe): + return calculate_subblock_params(subblock_config, n_embd, n_head) + else: + return estimate_moe_active_params( + subblock_config, n_embd, moe_stats_file, batch_size, block_idx + ) + + +def load_moe_stats(stats_file: str) -> dict: + with open(stats_file, "r") as f: + stats = json.load(f) + return [np.array(l) / np.sum(l) if len(l) > 0 else 0 for l in stats] + + +def estimate_num_active_experts( + dist_over_experts: np.ndarray, batch_size: int, num_experts: int +) -> int: + # cut the tail and renormalize + dist_over_experts = np.sort(dist_over_experts)[::-1][:num_experts] + dist_over_experts = dist_over_experts / (dist_over_experts.sum()) + # calculate the probability of at least one expert being active + # (expectation on indicators is the expected number of active experts) + return (1 - (1 - dist_over_experts) ** batch_size).sum() + + +def estimate_moe_active_params( + subblock_config: FFNConfig, + n_embd: int, + moe_stats_file: Path | str, + batch_size: int, + block_idx: int, +) -> int: + assert Path(moe_stats_file).exists() + # if not Path(moe_stats_file).exists(): # if path is not provided, should we assume uniform distribution? + # return calculate_subblock_params(subblock_config, n_embd, n_head=None) + moe_stats = load_moe_stats(moe_stats_file) + dist_over_experts = moe_stats[block_idx] + num_experts = subblock_config.moe.num_local_experts + + expected_num_active_experts = estimate_num_active_experts( + dist_over_experts, batch_size, num_experts + ) + expert_dim = subblock_config.moe.expert_intermediate_dim + shared_expert_dim = subblock_config.moe.shared_expert_intermediate_dim + num_linear_layers = 3 # all moe experts have 3 linear layers + + router_num_params = n_embd * num_experts + expected_num_active_experts_params = ( + num_linear_layers * expert_dim * n_embd * expected_num_active_experts + ) + shared_expert_num_params = num_linear_layers * shared_expert_dim * n_embd + + expected_total_params = ( + router_num_params + expected_num_active_experts_params + shared_expert_num_params + ) + return expected_total_params + + +def calculate_attention_memory( + attention_config: AttentionConfig, + batch_size: int, + prefill_seq_len: int, + generation_seq_len: int, + prefill_queue_size: int, + n_embd: int, + n_head: int, + weights_dtype: torch.dtype | str, + kv_cache_dtype: torch.dtype, + allocate_prefill_query: bool, +) -> dict[str, float]: + """ + allocate_prefill_query: infery-llm style. + Infery used a unified Wqkv matrix, so before extracting the kv-cache, + the query also had to be kept in-memory, once per layer. + """ + seq_len = prefill_seq_len + generation_seq_len + if ( + attention_config.is_llama4 + and (attention_chunk_size := attention_config.llama4.attention_chunk_size) is not None + ): + seq_len = min(seq_len, attention_chunk_size) + + kv_dim = calculate_kv_dim(attention_config.n_heads_in_group, n_head, n_embd) + total_num_tokens = seq_len * (batch_size + prefill_queue_size) + kv_cache_size = total_num_tokens * kv_dim + query_prefill_size = seq_len * n_embd if allocate_prefill_query else 0 + num_params = calculate_attention_params(attention_config, n_embd, n_head) + total_memory = ( + kv_cache_size * sizeof_dtype(kv_cache_dtype) + + query_prefill_size * sizeof_dtype(weights_dtype) + + num_params * sizeof_dtype(weights_dtype) + ) / 2**20 + kv_cache_memory = kv_cache_size * sizeof_dtype(kv_cache_dtype) / 2**20 + return {"memory_mib": total_memory, "kv_cache_memory_mib": kv_cache_memory} + + +def calculate_attention_params( + attention_config: AttentionConfig, + n_embd: int, + n_head: int, +) -> int: + kv_dim = calculate_kv_dim(attention_config.n_heads_in_group, n_head, n_embd) + return ( + n_embd * n_embd * 2 # Wq + Wo + + n_embd * kv_dim # Wk + Wv + + n_embd # rms norm + ) + + +def calculate_mamba_memory( + mamba_config: MambaConfig, + n_embd: int, + batch_size: int, + weights_dtype: torch.dtype | str, + kv_cache_dtype: torch.dtype | str, +) -> int: + return ( + calculate_mamba_params(mamba_config, n_embd) * sizeof_dtype(weights_dtype) + + calculate_mamba_state_size(mamba_config, batch_size) * sizeof_dtype(kv_cache_dtype) + ) / 2**20 + + +def calculate_mamba_params( + mamba_config: MambaConfig, + n_embd: int, +) -> int: + d_inner, in_proj_dim, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config) + param_shapes = { + "A_log": (mamba_config.num_heads,), + "D": (mamba_config.num_heads,), + "conv1d.bias": (conv_dim,), + "conv1d.weight": (conv_dim, 1, kernel_size), + "dt_bias": (mamba_config.num_heads,), + "in_proj.weight": (in_proj_dim, n_embd), + "norm.weight": (d_inner,), + "out_proj.weight": (n_embd, d_inner), + } + mamba_mixer_params = sum([math.prod(shape) for shape in param_shapes.values()]) + rms_norm_params = n_embd + return mamba_mixer_params + rms_norm_params + + +def calculate_mamba_state_size( + mamba_config: MambaConfig, + batch_size: int, +) -> int: + d_inner, in_proj_dim, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config) + conv_state_size = math.prod((batch_size, conv_dim, kernel_size)) + ssm_state_size = math.prod( + (batch_size, mamba_config.num_heads, mamba_config.head_dim, mamba_config.state_dim) + ) + return conv_state_size + ssm_state_size + + +def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...]: + d_inner = mamba_config.num_heads * mamba_config.head_dim + in_proj_dim = ( + d_inner * 2 + 2 * mamba_config.num_groups * mamba_config.state_dim + mamba_config.num_heads + ) + conv_dim = d_inner + 2 * mamba_config.num_groups * mamba_config.state_dim + kernel_size = 4 + return d_inner, in_proj_dim, conv_dim, kernel_size + + +def calculate_linear_memory( + n_embd: int, + weights_dtype: torch.dtype | str, +) -> float: + return calculate_linear_params(n_embd) * sizeof_dtype(weights_dtype) / 2**20 + + +def calculate_linear_params( + n_embd: int, +) -> int: + return n_embd**2 + n_embd + + +def calculate_ffn_memory( + ffn_config: FFNConfig, + n_embd: int, + weights_dtype: torch.dtype | str, +) -> float: + num_params = calculate_ffn_params(ffn_config, n_embd) + return num_params * sizeof_dtype(weights_dtype) / 2**20 + + +def calculate_ffn_params( + ffn_config: FFNConfig, + n_embd: int, +) -> float: + if ffn_config.is_moe: + return calculate_moe_params(ffn_config, n_embd) + else: + return calculate_dense_ffn_params(ffn_config, n_embd) + + +def calculate_dense_ffn_params( + ffn_config: FFNConfig, + n_embd: int, +) -> int: + intermediate_size = ffn_config.intermediate_size + num_linear_layers = 3 if getattr(ffn_config, "gated", True) else 2 + rms_norm_params = n_embd + return n_embd * intermediate_size * num_linear_layers + rms_norm_params + + +def calculate_moe_params( + ffn_config: FFNConfig, + n_embd: int, +) -> int: + with torch.device("meta"): + config = DeciLMConfig(hidden_size=n_embd) + moe = DeciLMMoe(config, ffn_config) + moe_params = sum(p.numel() for p in moe.parameters()) + layernorm_params = n_embd + return moe_params + layernorm_params + + +def calculate_non_block_memory( + n_embd: int, + vocab_size: int, + weight_dtype: torch.dtype, +) -> float: + return calculate_non_block_params(n_embd, vocab_size) * sizeof_dtype(weight_dtype) / 2**20 + + +def calculate_non_block_params( + n_embd: int, + vocab_size: int, +) -> int: + return vocab_size * n_embd * 2 + n_embd diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py b/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py new file mode 100644 index 0000000000..d3e73a0cf8 --- /dev/null +++ b/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py @@ -0,0 +1,554 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Calc subblock stats to compute memory and runtime statistics for subblocks.""" + +import os +from itertools import product + +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +import dataclasses +import json +from functools import partial +from pathlib import Path +from typing import Iterable, Optional, Type, TypeVar + +import hydra +import pandas as pd +import torch +from immutabledict import immutabledict +from omegaconf import DictConfig, ListConfig, OmegaConf +from tqdm import tqdm + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + SubblockConfig, +) +from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement +from modelopt.torch._compress.subblock_stats.calc_subblock_params_and_memory import ( + calc_subblock_active_params, + calculate_non_block_memory, + calculate_non_block_params, + calculate_subblock_memory, + calculate_subblock_params, +) +from modelopt.torch._compress.tools.checkpoint_utils import load_model_config +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch._compress.utils.parsing import format_global_config + +# Type variable for dataclasses +T_DataClass = TypeVar("T_DataClass") + +""" +Usage: +python -m modelopt.torch._compress.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ] + +--benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime, + only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker. + +""" + + +def calculate_subblock_stats( + calc_subblock_stats_config: DictConfig, + teacher_dir: Path, + master_puzzle_dir: Path, + subblock_configs: list[immutabledict[str, AttentionConfig | FFNConfig]], + batch_size: int, + prefill_seq_len: int, + generation_seq_len: int, + prefill_queue_size: int, + n_embd: int, + n_head: int, + vocab_size: int, + benchmark_iterations: Optional[int], + use_cuda_graph: bool, + weights_dtype: torch.dtype, + activations_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + allocate_prefill_query: bool, + moe_stats_file: str | Path | None = None, +) -> dict: + is_calc_runtime = benchmark_iterations is not None + if is_calc_runtime: + from puzzle_tools.subblock_stats.runtime_stats.calc_runtime_stats import ( + calc_runtime_ms_for_subblocks, + ) + + gpu = None if not torch.cuda.is_available() else torch.cuda.get_device_name() + subblock_stats = { + "args": dict( + is_calc_runtime=is_calc_runtime, + gpu=gpu, + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + generation_seq_len=generation_seq_len, + prefill_queue_size=prefill_queue_size, + n_embd=n_embd, + n_head=n_head, + vocab_size=vocab_size, + benchmark_iterations=benchmark_iterations, + use_cuda_graph=use_cuda_graph, + weights_dtype=str(weights_dtype), + activations_dtype=str(activations_dtype), + kv_cache_dtype=str(kv_cache_dtype), + ), + "non_block": dict(), + "subblocks": list(), + } + # Compute runtime stats for unique subblocks only + if is_calc_runtime: + subblock_configs_nolayerindex = set( + [subblock_config["subblock_config"] for subblock_config in subblock_configs] + ) + + # dict[SubblockConfig, float], float + # TODO: Manage default values for calc_subblock_stats_config in one place, e.g. within a dataclass for hydra config. + synth_dataset_num_requests = calc_subblock_stats_config.get("runtime_stats", {}).get( + "synth_dataset_num_requests", 200 + ) + backend = calc_subblock_stats_config.get("runtime_stats", {}).get("backend", "trt_torch") + runtime_by_subblock_dict, non_block_runtime_ms = calc_runtime_ms_for_subblocks( + subblock_configs_nolayerindex, + vocab_size, + n_embd, + n_head, + master_puzzle_dir, + teacher_dir, + synth_dataset_num_requests, + backend, + ) + + sorted_subblock_config = sorted( + subblock_configs, key=lambda subblock_config: subblock_config["subblock_config"] + ) + it = ( + tqdm(sorted_subblock_config, desc="Measuring subblock runtimes") + if is_calc_runtime + else sorted_subblock_config + ) + for subblock_config_indexed in it: + subblock_config = subblock_config_indexed["subblock_config"] + parent_layer_indices = subblock_config_indexed["parent_layer_indices"] + + if is_calc_runtime: + total_runtime_ms = runtime_by_subblock_dict[subblock_config] + prefill_runtime_ms = None + decode_runtime_ms = None + else: + total_runtime_ms, prefill_runtime_ms, decode_runtime_ms = None, None, None + + subblock_memory = calculate_subblock_memory( + subblock_config, + batch_size, + prefill_seq_len, + generation_seq_len, + prefill_queue_size, + n_embd, + n_head, + weights_dtype, + kv_cache_dtype, + allocate_prefill_query, + ) + if not isinstance(subblock_memory, dict): + subblock_memory = {"memory_mib": subblock_memory, "kv_cache_memory_mib": 0.0} + + subblock_params = calculate_subblock_params(subblock_config, n_embd, n_head) + if moe_stats_file is not None: + subblock_active_params = calc_subblock_active_params( + subblock_config, n_embd, n_head, moe_stats_file, batch_size, parent_layer_indices[0] + ) + else: + subblock_active_params = subblock_params + subblock_stats["subblocks"].append( + { + "subblock_config": subblock_config, + "subblock_config_class": type(subblock_config).__name__, + "runtime_ms": total_runtime_ms, + "prefill_runtime_ms": prefill_runtime_ms, + "decode_runtime_ms": decode_runtime_ms, + "num_params": subblock_params, + "active_params": subblock_active_params, + "parent_layer_index": parent_layer_indices[0], + **subblock_memory, + } + ) + + if is_calc_runtime: + pass + # TODO: fix + # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms + # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ + # measure_non_block_runtime_ms(batch_size, prefill_seq_len, generation_seq_len, n_embd, vocab_size, + # benchmark_iterations, use_cuda_graph) + embedding_runtime_ms, lm_head_runtime_ms = None, None + else: + non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = None, None, None + non_block_memory = calculate_non_block_memory(n_embd, vocab_size, weights_dtype) + non_block_params = calculate_non_block_params(n_embd, vocab_size) + + # TODO + # the semantics here is wrong why do we refer, prefill_runtime_ms as embedding_runtime_ms and lm_head_runtime_ms as decode_runtime_ms ? + # Prefill is the first the user prompt inference, and Decode refer to the next generation process. both processes use all the model layers. + subblock_stats["non_block"] = { + "runtime_ms": non_block_runtime_ms, + "prefill_runtime_ms": embedding_runtime_ms, + "decode_runtime_ms": lm_head_runtime_ms, + "memory_mib": non_block_memory, + "num_params": non_block_params, + } + return subblock_stats + + +def launch_calc_subblock_stats(cfg: DictConfig) -> None: + """ + Launch the calc subblock stats function with Hydra configuration. + """ + mprint(f"Calculating subblock stats for puzzle directory: {cfg.puzzle_dir}") + mprint(f"Teacher directory: {cfg.teacher_dir}") + mprint( + f"Calc subblock stats config: {format_global_config(cfg.calc_subblock_stats, title='Calc subblock stats')}" + ) + + calculate_subblock_stats_for_puzzle_dir( + cfg.calc_subblock_stats, + master_puzzle_dir=cfg.puzzle_dir, + teacher_dir=cfg.teacher_dir, + model_hidden_sizes=cfg.calc_subblock_stats.get("model_hidden_sizes", OmegaConf.create([])), + ffn_hidden_sizes=cfg.calc_subblock_stats.get("ffn_hidden_sizes", OmegaConf.create([])), + batch_sizes=cfg.calc_subblock_stats.batch_sizes, + prefill_seq_len=cfg.calc_subblock_stats.prefill_seq_len, + generation_seq_len=cfg.calc_subblock_stats.generation_seq_len, + num_active_tokens_override=cfg.calc_subblock_stats.get("num_active_tokens_override", None), + prefill_queue_size=cfg.calc_subblock_stats.prefill_queue_size, + allocate_prefill_query=cfg.calc_subblock_stats.allocate_prefill_query, + benchmark_iterations=cfg.calc_subblock_stats.get("benchmark_iterations", None), + merge_with_existing_stats=cfg.calc_subblock_stats.merge_with_existing_stats, + subblock_stats_filename=cfg.calc_subblock_stats.subblock_stats_filename, + moe_stats_filename=cfg.calc_subblock_stats.moe_stats_filename, + ) + + +def calculate_subblock_stats_for_puzzle_dir( + calc_subblock_stats_config: DictConfig, + master_puzzle_dir: Path | str, + teacher_dir: Path | str, + model_hidden_sizes: ListConfig, + ffn_hidden_sizes: ListConfig, + batch_sizes: Iterable[int] = (1, 8, 16, 32, 64, 128, 256), + prefill_seq_len: int = 2048, + generation_seq_len: int = 2048, + num_active_tokens_override: int | None = None, + prefill_queue_size: int = 0, # it's an infery-llm thing + allocate_prefill_query: bool = False, + benchmark_iterations: ( + int | None + ) = None, # If set then compute runtime performance statistics. TODO: recommend default value, is 1000 good? + merge_with_existing_stats: bool = False, + subblock_stats_filename: str = "subblock_stats.json", + moe_stats_filename: str = "moe_stats.json", +) -> None: + # ==== START === Setup for attach-helper ==== + # import sys + # import os + # sys.path.insert(0, os.environ["ATTACH_HELPER_INSTALLATION_PATH"]) + # from attach_helper import debugging_setup + # debugging_setup() # You can optionally pass a name to identify the job (e.g. `debugging_setup(name="my_script")`) + # ==== END === Setup for attach-helper ==== + if isinstance(batch_sizes, str): + batch_sizes = [ + int(batch_size) for batch_size in batch_sizes.strip("[]").replace(" ", "").split(",") + ] + + master_puzzle_dir = Path(master_puzzle_dir) + teacher_dir = ( + Path(teacher_dir) if teacher_dir is not None else master_puzzle_dir / "ckpts" / "teacher" + ) + model_config = load_model_config(teacher_dir) + subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes, model_config) + + subblock_stats_file = master_puzzle_dir / subblock_stats_filename + if subblock_stats_file.exists() and not merge_with_existing_stats: + raise ValueError( + f"Subblock stats file {subblock_stats_file} already exists and `merge_with_existing_stats` was set to False." + ) + + if subblock_stats_file.exists(): + with open(subblock_stats_file) as f: + subblock_stats = json.load(f) + else: + subblock_stats = [] + + moe_stats_file = master_puzzle_dir / moe_stats_filename + if not moe_stats_file.exists(): + Warning( + f"MOE stats file {moe_stats_file} does not exist, can't calculate num active params" + ) + moe_stats_file = None + + subblock_stats_args = {immutabledict(x["args"]) for x in subblock_stats} + + data_types = [ + ("nvfp4", "nvfp4", "nvfp4"), + (torch.int8, torch.int8, torch.int8), + (torch.int8, torch.int8, torch.bfloat16), + (torch.bfloat16, torch.bfloat16, torch.bfloat16), + ] + + model_hidden_sizes = model_hidden_sizes + [ + model_config.hidden_size + ] # add a teacher model hidden size + for batch_size, ( + weights_dtype, + activations_dtype, + kv_cache_dtype, + ), model_hidden_size in product(batch_sizes, data_types, model_hidden_sizes): + if num_active_tokens_override is not None: + prefill_seq_len = generation_seq_len = int(num_active_tokens_override / batch_size / 2) + + curr_benchmark_iterations = ( + benchmark_iterations if weights_dtype == torch.bfloat16 else None + ) + + curr_subblock_stats = calculate_subblock_stats( + calc_subblock_stats_config, + teacher_dir=teacher_dir, + master_puzzle_dir=master_puzzle_dir, + subblock_configs=subblock_configs, + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + generation_seq_len=generation_seq_len, + prefill_queue_size=prefill_queue_size, + n_embd=model_hidden_size, + n_head=model_config.num_attention_heads, + vocab_size=model_config.vocab_size, + benchmark_iterations=curr_benchmark_iterations, + use_cuda_graph=True, + weights_dtype=weights_dtype, + activations_dtype=activations_dtype, + kv_cache_dtype=kv_cache_dtype, + allocate_prefill_query=allocate_prefill_query, + moe_stats_file=moe_stats_file, + ) + + if immutabledict(curr_subblock_stats["args"]) in subblock_stats_args: + raise ValueError( + f"Failed merging subblock_stats. The following arguments already existed in the file: {curr_subblock_stats['args']}" + ) + + subblock_stats.append(curr_subblock_stats) + + # TODO fix: add_int8_runtime_estimates(subblock_stats) + + json_dump(subblock_stats, subblock_stats_file) + + mprint(subblock_stats_file) + + +def _load_subblock_configs( + master_puzzle_dir: Path, ffn_hidden_sizes: ListConfig, model_config: DeciLMConfig +) -> list[SubblockConfig]: + try: + subblock_configs = _load_subblock_configs_from_replacement_library(master_puzzle_dir) + except FileNotFoundError: + subblock_configs = _load_subblock_configs_from_subblock_library(master_puzzle_dir) + + # Extend subblock stats calculation space with ffn_hidden_sizes defined in the calc_subblock_stats section of the model config yaml file. + extra_ffn_subblock_configs = [] + for ffn_hidden_size in ffn_hidden_sizes: + # Use FFNConfig defaults (hidden_act will use its default value) + ffn_config = FFNConfig(intermediate_size=ffn_hidden_size) + extra_ffn_subblock_configs.append( + immutabledict({"subblock_config": ffn_config, "parent_layer_indices": tuple([-1])}) + ) # -1 to indicate that this sublock has no parent layer + subblock_configs.extend(extra_ffn_subblock_configs) + + return subblock_configs + + +def _load_subblock_configs_from_subblock_library(master_puzzle_dir: Path) -> list[SubblockConfig]: + subblocks_df = pd.read_json(master_puzzle_dir / "subblock_library.json") + subblocks_df["attention_config"] = subblocks_df["attention_config"].apply( + partial(_dataclass_from_dict, cls=AttentionConfig) + ) + subblocks_df["ffn_config"] = subblocks_df["ffn_config"].apply( + partial(_dataclass_from_dict, cls=FFNConfig) + ) + attention_configs = subblocks_df["attention_config"].dropna().drop_duplicates().tolist() + ffn_configs = subblocks_df["ffn_config"].dropna().drop_duplicates().tolist() + subblock_configs = attention_configs + ffn_configs + return subblock_configs + + +def _load_subblock_configs_from_replacement_library( + master_puzzle_dir: Path, +) -> list[SubblockConfig]: + """Load unique subblocks from replacement_library.json, e.g., + 256 = 32*8 unique sublocks will be returned for a model with 32 layers and the search space of + 4 intermediate_size + teacher_intermediate_size + ffn_noop + att_op (teacher) + att_noop. + + Args: + master_puzzle_dir (Path): Directory with "replacement_library.json" file + + Returns: + list[SubblockConfig]: + """ + replacement_library = json.loads((master_puzzle_dir / "replacement_library.json").read_text()) + subblock_configs = set() + for layer_replacement in replacement_library: + layer_replacement = parse_layer_replacement(layer_replacement) + + for block_config in layer_replacement["child_block_configs"]: + block_config: BlockConfig + attention_frozen_dict = immutabledict( + { + "subblock_config": block_config.attention, + "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]), + } + ) + ffn_frozen_dict = immutabledict( + { + "subblock_config": block_config.ffn, + "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]), + } + ) + subblock_configs.add(attention_frozen_dict) + subblock_configs.add(ffn_frozen_dict) + + if block_config.parallel_blocks is not None: + for block_idx, internal_block_config in enumerate(block_config.parallel_blocks): + attention_frozen_dict = immutabledict( + { + "subblock_config": internal_block_config.attention, + "parent_layer_indices": tuple( + layer_replacement["parent_layer_indices"] + ), + "inner_block_idx": block_idx, + } + ) + ffn_frozen_dict = immutabledict( + { + "subblock_config": internal_block_config.ffn, + "parent_layer_indices": tuple( + layer_replacement["parent_layer_indices"] + ), + "inner_block_idx": block_idx, + } + ) + subblock_configs.add(attention_frozen_dict) + subblock_configs.add(ffn_frozen_dict) + + subblock_configs = list(subblock_configs) + return subblock_configs + + +T_DataClass: TypeVar = Type[dataclasses.dataclass] + + +def _dataclass_from_dict( + d: dict | T_DataClass | None, + cls: T_DataClass, +) -> T_DataClass | None: + if isinstance(d, cls): + return d + if isinstance(d, dict): + return cls(**d) + if pd.isna(d): + return None + raise ValueError(f"_dataclass_from_dict: unrecognized {type(d)=} {d=}") + + +def add_int8_runtime_estimates(subblock_stats: list[dict]) -> None: + for curr_subblock_stats in subblock_stats: + args = curr_subblock_stats["args"] + if args["weights_dtype"] == "torch.int8": + assert args["activations_dtype"] == "torch.int8" + ffn_factor = 0.5 + attention_factor = 0.5 if args["kv_cache_dtype"] == "torch.int8" else 0.8 + + bf16_stats = _find_corresponding_bf16_stats(args, subblock_stats) + if bf16_stats is not None: + curr_subblocks = curr_subblock_stats["subblocks"] + [ + curr_subblock_stats["non_block"] + ] + bf16_subblocks = bf16_stats["subblocks"] + [bf16_stats["non_block"]] + for curr_subblock, bf16_subblock in zip(curr_subblocks, bf16_subblocks): + assert curr_subblock.get("subblock_config", None) == bf16_subblock.get( + "subblock_config", None + ) + is_attention = False + if (subblock_config := curr_subblock.get("subblock_config")) is not None: + if hasattr(subblock_config, "__dataclass_fields__"): + subblock_config = dataclasses.asdict(subblock_config) + is_attention = subblock_config.get("n_heads_in_group", None) is not None + runtime_factor = attention_factor if is_attention else ffn_factor + for stat_name, stat_value in bf16_subblock.items(): + if "runtime" in stat_name: + curr_subblock[stat_name] = stat_value * runtime_factor + + +def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> dict | None: + scenario_keys = [ + "batch_size", + "prefill_seq_len", + "generation_seq_len", + "prefill_queue_size", + "gpu", + "n_embd", + "n_head", + "vocab_size", + ] + corresponding_bf16_args = { + **{k: v for k, v in args.items() if k in scenario_keys}, + "is_calc_runtime": True, + "weights_dtype": "torch.bfloat16", + "activations_dtype": "torch.bfloat16", + "kv_cache_dtype": "torch.bfloat16", + } + matching_bf16_stats = [ + stats + for stats in subblock_stats + if all( + [ + stats["args"][key] == corresponding_bf16_args[key] + for key in corresponding_bf16_args.keys() + ] + ) + ] + if len(matching_bf16_stats) == 0: + return None + if len(matching_bf16_stats) == 1: + return matching_bf16_stats[0] + raise ValueError(f"Found more than 1 matching bf16 stats for {args=}") + + +@hydra.main("configs", version_base="1.3", config_name="search_space") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + mprint(format_global_config(cfg)) + launch_calc_subblock_stats(cfg) + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index 74329bcd0a..7acb3f3274 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -25,6 +25,56 @@ ) +def calculate_kv_dim(n_heads_in_group: int, n_head: int, n_embd: int) -> int: + """Calculate the key-value dimension for grouped-query attention. + + TODO: Consider a better place for this function. + Args: + n_heads_in_group: Number of attention heads per key-value group. + n_head: Total number of attention heads. + n_embd: Embedding dimension. + + Returns: + Combined dimension for key and value tensors (2 * n_kv_heads * head_size). + """ + if n_heads_in_group is None: + return 0 + n_kv_heads = n_head // n_heads_in_group + head_size = n_embd // n_head + kv_dim = 2 * n_kv_heads * head_size + return kv_dim + + +def raise_unknown_subblock_config_error(subblock_config: Any) -> None: + """Raise an error for invalid subblock configuration types. + + TODO: Consider a better place for this function. + Args: + subblock_config: The invalid subblock configuration object. + + Raises: + ValueError: Always raised with a message indicating the expected types. + """ + raise ValueError( + f"subblock_config should be an instance of FFNConfig or AttentionConfig, instead got {type(subblock_config)}" + ) + + +def sizeof_dtype(dtype: torch.dtype | str) -> int | float: + """Return the size in bytes of the given data type. + + TODO: Consider a better place for this function. + Args: + dtype: PyTorch data type or custom type string (e.g., 'nvfp4'). + + Returns: + Size in bytes of the data type. Special case: 'nvfp4' returns ~0.588 bytes. + """ + if dtype == "nvfp4": + return 1 / 1.7 + return torch.tensor([], dtype=dtype).element_size() + + def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: """ Convert a BlockConfig to a human-readable string representation. From dcc425f1f36c1693f5be94bdd2c186457028ffb2 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 2 Dec 2025 13:56:38 +0100 Subject: [PATCH 20/62] Add 1-block scoring to the compress algorithm (#625) ## What does this PR do? Add 1-block scoring to the compress algorithm. --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Daniel Korzekwa Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../nas/plugins/compress_nas_plugin.py | 2 +- modelopt/torch/_compress/scoring/scoring.py | 100 ++++++ .../torch/_compress/tools/validate_model.py | 5 +- ...validate_puzzle_with_multi_replacements.py | 330 ++++++++++++++++++ .../torch/_compress/tools/validation_utils.py | 120 +++++++ 5 files changed, 555 insertions(+), 2 deletions(-) create mode 100644 modelopt/torch/_compress/scoring/scoring.py create mode 100644 modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py create mode 100644 modelopt/torch/_compress/tools/validation_utils.py diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 72c40f729f..390ba835a7 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -24,11 +24,11 @@ from pathlib import Path import mip_and_realize_models -import scoring import torch from torch import nn import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch._compress.scoring.scoring as scoring from modelopt.torch._compress import build_library_and_stats from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( diff --git a/modelopt/torch/_compress/scoring/scoring.py b/modelopt/torch/_compress/scoring/scoring.py new file mode 100644 index 0000000000..f17b8cd3e3 --- /dev/null +++ b/modelopt/torch/_compress/scoring/scoring.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validates and scores model compression solutions by evaluating puzzle solution candidates.""" + +# mypy: ignore-errors +import os +import re +from glob import glob +from pathlib import Path + +import hydra +import numpy as np +import pandas as pd +import torch +from omegaconf import DictConfig + +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime +from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( + validate_puzzle_solutions, +) +from modelopt.torch._compress.utils.dist_utils import is_distributed + + +def extract_solution_id(filename): + pattern = r"solution_(\d+)\.json" + match = re.search(pattern, filename) + + if match: + solution_id = match.group(1) + return int(solution_id) + else: + mprint(f"Couldn't extract solutions_id from file {filename}") + + +def find_missing_solutions(solutions_df, validation_dir): + all_solutions = np.arange(solutions_df.shape[0]) + + benchmarked_solutions = list(glob(f"{validation_dir}/solution*.json")) + benchmarked_solutions = [ + extract_solution_id(os.path.basename(s)) for s in benchmarked_solutions + ] + benchmarked_solutions = [s for s in benchmarked_solutions if s is not None] + + unbenchmarked_solutions = np.setdiff1d(all_solutions, benchmarked_solutions) + return unbenchmarked_solutions.tolist() + + +def get_solutions_to_validate(cfg: DictConfig): + _solutions_to_validate = cfg.scoring.solutions_to_validate + if _solutions_to_validate is None: + single_block_replacement_solutions = pd.read_json(cfg.scoring.solutions_path) + if cfg.scoring.skip_existing_solutions: + _solutions_to_validate = find_missing_solutions( + single_block_replacement_solutions, cfg.scoring.output_dir + ) + else: + _solutions_to_validate = np.arange(single_block_replacement_solutions.shape[0]).tolist() + return _solutions_to_validate + + +def launch_scoring(cfg: DictConfig, runtime: IRuntime): + cfg.scoring.solutions_to_validate = get_solutions_to_validate(cfg) + mprint(f"Solutions to validate: {cfg.scoring.solutions_to_validate}") + validate_puzzle_solutions(args=cfg.scoring, runtime=runtime) + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + mprint(cfg) + + _runtime = ( + NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") + ) + if is_distributed() + else BaseRuntime(dtype=torch.bfloat16) + ) + with _runtime as runtime: + launch_scoring(cfg, runtime) + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index 47e8e4202d..8ec1d6f172 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -16,7 +16,10 @@ """ Provides a function to validate a model. Runs a model forward pass on a dataset and calculates the loss, and optionally registers hooks to capture the inputs and the outputs -of pytorch modules that are used for activation scoring for pruning.""" +of pytorch modules that are used for activation scoring for pruning. + +TODO: Consider moving this a separate module dedicated for scoring. +""" import argparse import textwrap diff --git a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py new file mode 100644 index 0000000000..e947e97e4e --- /dev/null +++ b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validates puzzle solutions by applying layer replacements and evaluating model performance. + +TODO: Consider moving this a separate module dedicated for scoring. +""" + +# mypy: ignore-errors + +import argparse +import json +import shutil +import warnings +from functools import partial +from pathlib import Path +from typing import Optional + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.replacement_library.build_replacement_library import infer_teacher_dir +from modelopt.torch._compress.replacement_library.replacement_library import ReplacementLibrary +from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement +from modelopt.torch._compress.tools import validate_model +from modelopt.torch._compress.tools.checkpoint_utils import ( + SAFETENSORS_SUBBLOCKS_DIR_NAME, + copy_tokenizer, +) +from modelopt.torch._compress.tools.checkpoint_utils_hf import ( + save_checkpoint, + save_safetensors_index, +) +from modelopt.torch._compress.tools.runtime import IRuntime +from modelopt.torch._compress.tools.validation_utils import ( + validate_model_and_extract_hidden_states, + validate_model_with_teacher_similarity_metrics, +) +from modelopt.torch._compress.utils.parsing import get_nested_key, parse_path +from modelopt.torch._compress.utils.validate_runtime_pipeline import perform_pipeline_stitches + +""" +Usage: +====== + +Validate single_block_replacement_solutions +=========================================== + +( +export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"; +PUZZLE_DIR=".../Llama-3_2-1B-Instruct/parallel_puzzle"; + +torchrun --nproc-per-node=8 \ + -m modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements \ + --replacement_library_path ${PUZZLE_DIR}/replacement_library.json \ + --solutions_path ${PUZZLE_DIR}/single_sequence_replacement_solutions.json \ + --solutions_to_validate 0 \ + \ + --dataset_path .../v0.4/valid \ + --data_column conversation --block_size 8192 --seed 42 --shuffle_seed 444 --bos_rate 0.5 \ + --eval_samples 32 --micro_batch_size 1 \ + \ + --save_models \ + +) + + +""" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--replacement_library_path", type=parse_path, required=True) + parser.add_argument("--solutions_path", type=parse_path, required=True) + parser.add_argument("--teacher_dir", type=parse_path, default=None) + parser.add_argument("--solutions_to_validate", type=int, nargs="+", default=None) + parser.add_argument("--sort_solutions_by", type=str, default=None) + parser.add_argument("--bigger_is_better", action="store_true") + parser.add_argument("--skip_validation", action="store_true") + parser.add_argument("--save_models", action="store_true") + args, unknown_args = parser.parse_known_args() + if not args.skip_validation: + validation_args = validate_model.build_arg_parser().parse_args(unknown_args) + args = argparse.Namespace( + **{**validation_args.__dict__, **args.__dict__} + ) # if arg names overlap, the latter one wins + else: + args.block_size = None + + args.teacher_dir = _try_infer_teacher_dir(args.replacement_library_path, args.teacher_dir) + + args.tokenizer_name = getattr(args, "tokenizer_name", None) + if args.tokenizer_name is None: + args.tokenizer_name = args.teacher_dir + + return args + + +@torch.no_grad() +def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> None: + puzzle_solutions = load_puzzle_solutions( + args.solutions_path, args.sort_solutions_by, args.bigger_is_better + ) + if args.solutions_to_validate is None: + args.solutions_to_validate = list(range(len(puzzle_solutions))) + puzzle_solutions = [puzzle_solutions[i] for i in args.solutions_to_validate] + + tokenizer = _load_tokenizer(args) + if not args.skip_validation: + val_dataloader = ( + validate_model.prepare_dataloader(args, tokenizer) + if (runtime is None or runtime.is_main_process) + else None + ) + + output_dir = ( + args.output_dir + if getattr(args, "output_dir", None) is not None + else args.solutions_path.with_name(f"{args.solutions_path.stem}--validation") + ) + + replacement_library = ReplacementLibrary(args.replacement_library_path) + + teacher_hidden_states = None + if (args.teacher_dir is not None) and (not args.skip_validation): + teacher_model = replacement_library.load_checkpoint( + args.teacher_dir, runtime.world_size, runtime.global_rank + ) + teacher_model.to(runtime.device) + stitched_model = perform_pipeline_stitches(teacher_model, runtime) + teacher_hidden_states = validate_model_and_extract_hidden_states( + args, + stitched_model, + tokenizer, + output_dir, + model_name="teacher", + runtime=runtime, + val_dataloader=val_dataloader, + ) + + for i_solution, puzzle_solution in tqdm( + list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions" + ): + layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution) + realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) + # realizable_as_symlinks = False + model_config = replacement_library.create_model_config(layer_replacements) + if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): + model = replacement_library.load_model( + layer_replacements, runtime.world_size, runtime.global_rank + ) + model_config = model.config + + if args.save_models: + checkpoint_dir = ( + args.solutions_path.with_name(f"{args.solutions_path.stem}--checkpoints") + / f"solution_{i_solution}" + ) + + model_config.dtype = "bfloat16" + model_config.architectures = ["DeciLMForCausalLM"] + if realizable_as_symlinks: + if runtime.global_rank == 0: + save_checkpoint_as_symlinks( + layer_replacements, model_config, checkpoint_dir, replacement_library + ) + else: + save_checkpoint(model, checkpoint_dir) + + copy_tokenizer(args.tokenizer_name, checkpoint_dir) + copy_hf_code(checkpoint_dir) + + runtime.wait_for_everyone() + + runtime.wait_for_everyone() + + if not args.skip_validation: + model.to(runtime.device) + stitched_model = perform_pipeline_stitches(model, runtime) + validate_model_with_teacher_similarity_metrics( + args, + stitched_model, + tokenizer, + teacher_hidden_states, + output_dir, + model_name=f"solution_{i_solution}", + extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, + runtime=runtime, + val_dataloader=val_dataloader, + ) + + runtime.wait_for_everyone() + + +def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool: + for layer_replacement in layer_replacements: + num_parent_layers = len(layer_replacement["parent_layer_indices"]) + num_child_layers = len(layer_replacement["child_block_configs"]) + if num_parent_layers != num_child_layers or num_parent_layers != 1: + return False + return True + + +def force_create_symlink(src: Path, dst: Path) -> None: + if dst.exists(): + dst.unlink() + dst.symlink_to(src) + + +def save_checkpoint_as_symlinks( + layer_replacements: list[dict], + model_config: DeciLMConfig, + checkpoint_dir: Path, + replace_library: ReplacementLibrary, +) -> None: + model_config.save_pretrained(checkpoint_dir) + (checkpoint_dir / "subblocks_safetensors").mkdir(parents=True, exist_ok=True) + save_safetensors_index(model_config, checkpoint_dir) + + for layer_replacement in layer_replacements: + for weight_path in layer_replacement["weight_paths"]: + force_create_symlink( + weight_path, checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME / weight_path.name + ) + + lm_head_path = replace_library.get_teacher_lm_head_path() + force_create_symlink( + lm_head_path, checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME / lm_head_path.name + ) + + embedding_path = replace_library.get_teacher_embedding_path() + force_create_symlink( + embedding_path, checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME / embedding_path.name + ) + + +def copy_hf_code(checkpoint_dir: Path) -> None: + code_dir = Path(__file__).parent / "deci_lm_hf_code" + print(f"copying hf code from {code_dir} ") + for file in code_dir.glob("*.py"): + shutil.copy(file, checkpoint_dir / file.name) + + +def _try_infer_teacher_dir( + replacement_library_path: str | Path, + teacher_dir: str | Path | None, +) -> Path | None: + if teacher_dir is not None: + return teacher_dir + + try: + teacher_dir = infer_teacher_dir( + master_puzzle_dir=Path(replacement_library_path).parent, teacher_checkpoint_dir=None + ) + return teacher_dir + except: + return None + + +def _load_tokenizer(args: argparse.Namespace) -> PreTrainedTokenizerBase: + tokenizer = None + if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + elif args.teacher_dir is not None: + try: + tokenizer = AutoTokenizer.from_pretrained(args.teacher_dir, trust_remote_code=True) + except: + pass + if tokenizer is None: + warnings.warn("Couldn't find a tokenizer, trying to continue without one") + return tokenizer + + +def _extract_layer_replacements_from_puzzle_solution( + puzzle_solution: dict, +) -> list[dict]: + puzzle_solution = puzzle_solution.get("puzzle_solution", puzzle_solution) + layer_replacements = [ + parse_layer_replacement(rep) for rep in puzzle_solution["chosen_replacements"] + ] + return layer_replacements + + +def load_puzzle_solutions( + solutions_path: Path, + sort_solutions_by: Optional[str], + bigger_is_better: bool, +) -> list[dict]: + assert solutions_path.exists(), f"{solutions_path=} does not exist" + + if solutions_path.is_file(): + puzzle_solutions = json.loads(solutions_path.read_text()) + if isinstance(puzzle_solutions, dict): + puzzle_solutions = [puzzle_solutions] + else: + puzzle_solutions = [ + json.loads(p.read_text()) for p in solutions_path.glob("*solution*.json") + ] + + if len(puzzle_solutions) == 0: + raise ValueError(f"No solutions under {solutions_path=}") + + if sort_solutions_by is not None: + puzzle_solutions = sorted( + puzzle_solutions, key=partial(get_nested_key, field=sort_solutions_by) + ) + if bigger_is_better: + puzzle_solutions = puzzle_solutions[::-1] + vals = [get_nested_key(sol, sort_solutions_by) for sol in puzzle_solutions] + print(f"sorted solutions by {sort_solutions_by}. {vals[:10]=} {vals[-10:]=}") + + return puzzle_solutions + + +if __name__ == "__main__": + validate_puzzle_solutions(args=parse_args()) diff --git a/modelopt/torch/_compress/tools/validation_utils.py b/modelopt/torch/_compress/tools/validation_utils.py new file mode 100644 index 0000000000..907dee4029 --- /dev/null +++ b/modelopt/torch/_compress/tools/validation_utils.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for validating models and extracting hidden states and similarity metrics. + +TODO: Consider moving this a separate module dedicated for scoring. +""" + +# mypy: ignore-errors + +import argparse +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf +from torch import nn +from transformers import PreTrainedTokenizerBase + +from modelopt.torch._compress.sewing_kit import StitchedModule +from modelopt.torch._compress.tools import validate_model +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch._compress.tools.runtime import IRuntime +from modelopt.torch._compress.utils.validation import LowMemorySparseTensor + + +def validate_model_and_extract_hidden_states( + args: argparse.Namespace, + model: nn.Module | StitchedModule, + tokenizer: PreTrainedTokenizerBase, + output_dir: Union[str, Path], + model_name: str, + extra_payload: Optional[dict[str, Any]] = None, + runtime: Optional[IRuntime] = None, + val_dataloader=None, +) -> list[torch.Tensor | LowMemorySparseTensor]: + mprint(f""" + +################################################################ +validate_model_and_extract_token_probs({model_name=}) +################################################################ + +""") + losses, hidden_states_per_batch = validate_model.validate_model( + args, + model, + tokenizer, + return_hidden_states=True, + runtime=runtime, + val_dataloader=val_dataloader, + ) + if runtime is None or runtime.is_last_process: + output_dir = output_dir if (output_dir is not None) else args.bypass_dir + extra_payload = extra_payload if (extra_payload is not None) else dict() + write_results(output_dir, model_name, args, {**losses, **extra_payload}) + return hidden_states_per_batch + + +def validate_model_with_teacher_similarity_metrics( + args: argparse.Namespace, + model: nn.Module | StitchedModule, + tokenizer: PreTrainedTokenizerBase, + target_hidden_states_per_batch: list[torch.Tensor], + output_dir: Union[str, Path], + model_name: str, + extra_payload: Optional[dict[str, Any]] = None, + runtime: Optional[IRuntime] = None, + calculate_full_score_ablations: bool = False, + val_dataloader=None, +) -> None: + is_calc_kl_div = target_hidden_states_per_batch is not None + mprint(f""" + +################################################################ +validate_model_with_kl_div({model_name=}, {is_calc_kl_div=}) +################################################################ + +""") + losses, _ = validate_model.validate_model( + args, + model, + tokenizer, + target_hidden_states_per_batch=target_hidden_states_per_batch, + runtime=runtime, + calculate_full_score_ablations=calculate_full_score_ablations, + val_dataloader=val_dataloader, + ) + if runtime is None or runtime.is_last_process: + extra_payload = extra_payload if (extra_payload is not None) else dict() + write_results(output_dir, model_name, args, {**losses, **extra_payload}) + + +def write_results( + output_dir: Union[str, Path], + result_name: str, + args: argparse.Namespace, + payload: dict[str, Any], +) -> None: + output_path = Path(output_dir) / f"{result_name}.json" + output_path.parent.mkdir(parents=True, exist_ok=True) + results = { + **payload, + "args": OmegaConf.to_container(args, resolve=True) + if isinstance(args, DictConfig) + else args.__dict__, + } + json_dump(results, output_path) From 56d95de0eac9dc2cbd79e5f2d8bcc9649c828329 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 2 Dec 2025 19:59:04 +0100 Subject: [PATCH 21/62] Add checkpoint save/load to ForwardHook + add IterativeChannelContributionHook (#610) ## What does this PR do? Add checkpoint save/load to ForwardHook + add IterativeChannelContributionHook. --------- Signed-off-by: Daniel Korzekwa --- modelopt/torch/nas/plugins/megatron.py | 23 +- modelopt/torch/nas/plugins/megatron_hooks.py | 266 +++++++++++++++++- .../torch/nas/plugins/test_megatron_hooks.py | 121 ++++++++ .../test_mcore_gpt_minitron_pruning.py | 8 +- 4 files changed, 397 insertions(+), 21 deletions(-) create mode 100644 tests/gpu/torch/nas/plugins/test_megatron_hooks.py diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 9796c5289e..0eccc11bc9 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -268,15 +268,12 @@ def _setup(self): max_ffn_size = int(self.get_hparam(self.hparam_name).max) # type: ignore[arg-type] activation_hook = MegatronL2NormHook(max_size=max_ffn_size) self._register_temp_attribute("_activation_hook", activation_hook) - # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute? + # _register_temp_attribute would not be enough instead of self.hook_handle to remove the hook from the module. self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook) ffn_hidden_size.register_importance(self._estimate_importance) def _estimate_importance(self) -> TracedHp.Importance: """Return the activation magnitude-based importance of the ffn_hidden_size.""" - assert self._activation_hook._activations is not None, ( - "No activations collected for importance estimation." - ) return self._activation_hook.accumulate() def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: @@ -597,7 +594,6 @@ def _setup(self): max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels activation_hook = MegatronL2NormHook(max_size=max_size) self._register_temp_attribute("_activation_hook", activation_hook) - # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute? self.hook_handle = self.linear_proj.register_forward_hook(activation_hook) # NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads, # otherwise we would only have aggregated importance of heads per group. @@ -607,9 +603,6 @@ def _setup(self): def _estimate_all_head_importance(self) -> TracedHp.Importance: """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups).""" - assert self._activation_hook._activations is not None, ( - "No activations collected for importance estimation." - ) # Convert squared sum to L2 norm scores = self._activation_hook.accumulate() attn_head_importance = torch.linalg.vector_norm( @@ -625,9 +618,6 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance: def _estimate_query_group_importance(self) -> TracedHp.Importance: """Return the importance of the ``num_query_groups`` hparam.""" - assert self._activation_hook._activations is not None, ( - "No activations collected for importance estimation." - ) # Convert squared sum to L2 norm scores = self._activation_hook.accumulate() group_importance = torch.linalg.vector_norm( @@ -1552,7 +1542,7 @@ def export(self) -> torch.nn.Module: def get_activations_and_layer_scores( self, - ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]: + ) -> tuple[list[dict[str, dict]], dict[int, torch.Tensor]]: """Get the per-rank activations and layer scores from the module.""" local_activations = {} for n, m in self.named_modules(): @@ -1560,7 +1550,8 @@ def get_activations_and_layer_scores( if hasattr(m, "_activations"): local_activations[n] = m._activations elif hasattr(m, "_activation_hook"): - local_activations[n] = m._activation_hook._activations + local_activations[n] = m._activation_hook.state_dict() + activations_per_rank = dist.allgather( local_activations, group=get_pipeline_model_parallel_group() ) @@ -1572,14 +1563,14 @@ def get_activations_and_layer_scores( def set_activations_and_layer_scores( self, - activations_per_rank: list[dict[str, torch.Tensor]], + activations_per_rank: list[dict[str, dict]], layer_scores: dict[int, torch.Tensor], ) -> None: """Set the pre-computed layer_scores and per-rank activations instead of running forward. Args: layer_scores: Dict from layer_number (1-indexed) to score. - activations_per_rank: List of dicts from module name to activations. Should match PP size. + activations_per_rank: List of dicts from module name to state dict. Should match PP size. """ rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() @@ -1593,7 +1584,7 @@ def set_activations_and_layer_scores( if hasattr(m, "_activations"): m._activations = activations_per_rank[rank][n] elif hasattr(m, "_activation_hook"): - m._activation_hook._activations = activations_per_rank[rank][n] + m._activation_hook.load_state_dict(activations_per_rank[rank][n]) def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None: diff --git a/modelopt/torch/nas/plugins/megatron_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks.py index 833e03c042..12e07c59ea 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks.py @@ -14,13 +14,27 @@ # limitations under the License. """Forward hooks for activation-based importance estimation (megatron NAS plugin).""" +import gc from abc import ABC, abstractmethod import torch +import torch.nn.functional as F from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region +from megatron.core.tensor_parallel.layers import RowParallelLinear from torch import nn +def clear_gpu_memory(clear: bool) -> None: + """Clear GPU memory cache if requested. + + Args: + clear: If True, runs garbage collection and empties CUDA cache. + """ + if clear: + gc.collect() + torch.cuda.empty_cache() + + class ForwardHook(ABC): """Base class for PyTorch forward hooks. @@ -48,6 +62,40 @@ def __call__( """ ... + @abstractmethod + def accumulate(self) -> torch.Tensor: + """Return accumulated importance scores. + + This method should be called after all forward passes to retrieve + the final importance scores for each channel/feature. + + Returns: + Tensor of importance scores, one per channel/feature. + + Raises: + AssertionError: If no activations have been collected yet. + """ + ... + + @abstractmethod + def state_dict(self) -> dict: + """Return the internal state for checkpointing. + + Returns: + dict: State dictionary containing checkpoint data. + Can contain tensors, ints, lists, etc. + """ + ... + + @abstractmethod + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint. + + Args: + state_dict: State dictionary previously returned by state_dict() + """ + ... + class MegatronL2NormHook(ForwardHook): """Hook for accumulating activation statistics for importance estimation. @@ -68,7 +116,14 @@ def __init__(self, max_size: int | None = None): def __call__( self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor ) -> None: - """Accumulate activation statistics from the forward pass.""" + """Accumulate activation statistics from the forward pass. + + Args: + module: The module this hook is registered on. + args: Tuple of input tensors. args[0] expected shape: [seq_len, batch_size, hidden_size] + (Megatron sequence-first format). + output: Output tensor from the module's forward pass. + """ # Gather input [seq_len, batch_size, hidden_size] over all TP regions # NOTE: This is not used at the moment since we restrict to TP=1 input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach() @@ -102,3 +157,212 @@ def accumulate(self) -> torch.Tensor: assert self._activations is not None, "No activations collected for importance estimation." # Convert squared sum to L2 norm return self._activations.pow(0.5) + + def state_dict(self) -> dict: + """Return the state dictionary containing activations.""" + return {"activations": self._activations} + + def load_state_dict(self, state_dict: dict) -> None: + """Load activations from checkpoint.""" + self._activations = state_dict["activations"] + + +def get_pruning_schedule(num_channels, pruning_iters): + """Spending decreases monotonically when num_channels >= pruning_iters. + + Intervals between spends increase monotonically when pruning_iters > num_channels. + The budget is fully utilized, and there's spending in the last iteration. + num_channels = 10, pruning_iters = 4 ==> [3, 3, 2, 2] + num_channels = 4, pruning_iters = 10 ==> [0, 1, 0, 1, 0, 0, 1, 0, 0, 1] + """ + if num_channels >= pruning_iters: + # Case when budget is greater than or equal to iterations + q = num_channels // pruning_iters # Base spend per iteration + r = num_channels % pruning_iters # Remainder to distribute + + schedule = [] + for i in range(pruning_iters): + if i < r: + # Assign higher spend to earlier iterations + schedule.append(q + 1) + else: + schedule.append(q) + else: + # Case when iterations are greater than budget + schedule = [0] * pruning_iters + for i in range(1, num_channels + 1): + # Distribute spends at positions where intervals increase monotonically + pos = ((i * pruning_iters) // num_channels) - 1 + schedule[pos] = 1 + return schedule + + +class IterativeChannelContributionHook(ForwardHook): + """Hook for iterative channel pruning based on contribution analysis. + + Progressively identifies and removes the least important input channels of a linear layer + by measuring channel contribution as the L2 norm of output change when removed. + + Args: + linear_layer: The linear projection layer to analyze. Can be either nn.Linear or + RowParallelLinear from megatron.core.tensor_parallel.layers. + activation_hooks_kwargs: Configuration dict with: + - validation_full_iters (int): Number of pruning iterations. + - clear_gpu_memory (bool, optional): Clear GPU memory during computation. + - calibration_method (str, optional): "scale_by_magnitude" or None. + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__( + self, + linear_layer: nn.Linear | RowParallelLinear, + activation_hooks_kwargs: dict, + max_size: int | None = None, + ): + """Initialize the iterative channel contribution hook.""" + self.weight_matrix = linear_layer.weight + + # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) + # TODO: Consider better design to handle RowParallelLinear and nn.Linear + if hasattr(linear_layer, "input_size"): + self.num_channels = linear_layer.input_size # Megatron-Core + else: + self.num_channels = linear_layer.in_features # PyTorch + + self.max_size = max_size + self.pruning_iters = activation_hooks_kwargs["validation_full_iters"] + self.clear_gpu_memory = activation_hooks_kwargs.get("clear_gpu_memory", False) + self.curr_iter = 0 + self.pruning_schedule = get_pruning_schedule( + num_channels=self.num_channels, pruning_iters=self.pruning_iters + ) + + self.agg_cont_per_channel = torch.zeros( + size=(self.num_channels,), + dtype=torch.float32, + device=self.weight_matrix.device, + ) + self.pruned_channels = [] + self.calibration_method = activation_hooks_kwargs.get("calibration_method") + self.epsilon = 1e-8 + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple + ) -> None: + """Compute channel contributions and prune channels according to schedule. + + Args: + module: The module this hook is registered on. + args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] + (PyTorch batch-first format). + output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) + for parallel layers. + """ + # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear) + # TODO: Consider better design to handle RowParallelLinear and nn.Linear + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + + activations = args[0] + + # Don't aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and activations.shape[-1] != self.max_size: + return + + n_channels_to_prune = self.pruning_schedule[self.curr_iter] + + curr_activations = activations.clone() # Shape B,T,I + curr_activations[..., self.pruned_channels] = 0 + output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E + + if self.calibration_method is None: + scaling_factor_per_token = torch.ones_like(output_tensor[..., 0]) # Shape B,T + elif self.calibration_method == "scale_by_magnitude": + output_norms = torch.linalg.vector_norm(output_tensor, dim=-1) # Shape B,T + output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T + scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon) + del output_curr_norms, output_norms + else: + raise NotImplementedError + del curr_activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + s = scaling_factor_per_token.unsqueeze(-1) * output_tensor - output_curr # Shape: (B, T, E) + s_squared_per_token = torch.sum(s**2, dim=-1) # Shape: (B, T) + b = s @ self.weight_matrix # Shape: (B, T, I) + c = torch.sum(self.weight_matrix**2, dim=0) # Shape: (I) + del s, output_curr + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution_squared = ( + s_squared_per_token.unsqueeze(2) + 2 * activations * b + (activations**2) * c + ) # Shape: (B, T, I) + del s_squared_per_token, b, c, activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution = torch.sqrt(contribution_squared + self.epsilon) # Shape: (B, T, I) + mean_cont_per_channel = torch.mean(contribution, dim=(0, 1)) # Shape: (I) + mean_cont_per_channel[self.pruned_channels] = torch.inf + del contribution, contribution_squared + clear_gpu_memory(clear=self.clear_gpu_memory) + + if n_channels_to_prune == 0: + self.agg_cont_per_channel += mean_cont_per_channel + else: + _, worst_indices = torch.topk(mean_cont_per_channel, n_channels_to_prune, largest=False) + worst_indices_list = worst_indices.tolist() + assert not set(self.pruned_channels).intersection(set(worst_indices_list)) + self.pruned_channels.extend(worst_indices_list) + self.agg_cont_per_channel.zero_() + self.curr_iter += 1 + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert pruning results to dict with channel importance rankings. + + Returns: + Dict with "score" (importance rank per channel) and + "channels_importance_ascending" (channel indices in ascending importance). + """ + assert self.num_channels == len(self.pruned_channels) + channels_importance_ascending = torch.tensor(self.pruned_channels, dtype=torch.long) + score = torch.empty(self.num_channels, dtype=torch.long) + score[channels_importance_ascending] = torch.arange(self.num_channels, dtype=torch.long) + + return { + "score": score.cpu(), + "channels_importance_ascending": channels_importance_ascending.cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return importance scores as a tensor. + + Returns: + Tensor of importance scores, one per channel. Lower scores indicate less important channels. + """ + return self.to_dict()["score"] + + def state_dict(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "curr_iter": self.curr_iter, + "pruned_channels": self.pruned_channels.copy(), + "agg_cont_per_channel": self.agg_cont_per_channel.cpu().clone(), + "num_channels": self.num_channels, + "pruning_iters": self.pruning_iters, + "pruning_schedule": self.pruning_schedule.copy(), + "calibration_method": self.calibration_method, + "epsilon": self.epsilon, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.curr_iter = state_dict["curr_iter"] + self.pruned_channels = state_dict["pruned_channels"].copy() + self.agg_cont_per_channel = state_dict["agg_cont_per_channel"].to(self.weight_matrix.device) + # Verify other parameters match + assert self.num_channels == state_dict["num_channels"], "Channel count mismatch" + assert self.pruning_iters == state_dict["pruning_iters"], "Iteration count mismatch" + assert self.pruning_schedule == state_dict["pruning_schedule"], "Pruning schedule mismatch" diff --git a/tests/gpu/torch/nas/plugins/test_megatron_hooks.py b/tests/gpu/torch/nas/plugins/test_megatron_hooks.py new file mode 100644 index 0000000000..f94f2e85f4 --- /dev/null +++ b/tests/gpu/torch/nas/plugins/test_megatron_hooks.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for megatron hooks.""" + +import torch +import torch.nn as nn +from _test_utils.import_helper import skip_if_no_megatron + +skip_if_no_megatron() + +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from megatron.core.parallel_state import initialize_model_parallel + +from modelopt.torch.nas.plugins.megatron_hooks import ( + IterativeChannelContributionHook, + MegatronL2NormHook, +) + + +def _test_iterative_channel_contribution_hook_with_shape(dim1: int, dim2: int): + """Helper function to test IterativeChannelContributionHook with given activation shape. + + Args: + dim1: First dimension of activation tensor (before in_features). + dim2: Second dimension of activation tensor (before in_features). + """ + torch.manual_seed(42) + + linear_layer = nn.Linear(in_features=6, out_features=4, bias=False) + activation_hooks_kwargs = { + "validation_full_iters": 3, + "clear_gpu_memory": False, + "calibration_method": None, + } + hook = IterativeChannelContributionHook(linear_layer, activation_hooks_kwargs) + linear_layer.register_forward_hook(hook) + + for _ in range(activation_hooks_kwargs["validation_full_iters"]): + activations = torch.randn(dim1, dim2, linear_layer.in_features) + _ = linear_layer(activations) + + results = hook.to_dict() + + # + # Assertions + # + assert results["score"].shape == (6,) + assert results["channels_importance_ascending"].shape == (6,) + + expected_scores = torch.tensor([5, 1, 3, 2, 4, 0]) + assert torch.equal(results["score"], expected_scores) + + expected_channels_asc = torch.tensor([5, 1, 3, 2, 4, 0]) + assert torch.equal(results["channels_importance_ascending"], expected_channels_asc) + + # Test that accumulate() returns the same scores as to_dict()["score"] + scores_from_accumulate = hook.accumulate() + assert torch.equal(scores_from_accumulate, expected_scores) + + +def test_iterative_channel_contribution_hook_sbi(): + """Test IterativeChannelContributionHook returns correct scores for input [seq_len, batch_size, in_features].""" + _test_iterative_channel_contribution_hook_with_shape(dim1=32, dim2=8) + + +def test_iterative_channel_contribution_hook_bsi(): + """Test IterativeChannelContributionHook returns correct scores for input [batch_size, seq_len, in_features].""" + _test_iterative_channel_contribution_hook_with_shape(dim1=8, dim2=32) + + +def _test_l2_norm_hook(rank, size): + """Internal test function that runs in spawned process with distributed setup.""" + # Initialize Megatron parallel state (distributed is already initialized by spawn_multiprocess_job) + initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + torch.manual_seed(42) + + linear_layer = nn.Linear(in_features=6, out_features=4, bias=False) + hook = MegatronL2NormHook(max_size=None) + linear_layer.register_forward_hook(hook) + + num_iterations = 3 + for _ in range(num_iterations): + activations = torch.randn(2, 3, linear_layer.in_features) + _ = linear_layer(activations) + + scores = hook.accumulate() + + # + # Assertions + # + assert scores.shape == (6,) + + expected_scores = torch.tensor( + [3.2030, 2.5018, 2.5272, 1.9222, 2.6204, 2.2623], dtype=torch.float32 + ) + assert torch.allclose(scores, expected_scores, atol=1e-4), ( + f"Expected scores {expected_scores}, got {scores}" + ) + + +def test_l2_norm_hook(): + """Test MegatronL2NormHook returns correct scores after accumulating activations.""" + spawn_multiprocess_job( + size=1, + job=_test_l2_norm_hook, + backend="gloo", + ) diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 094fc015d0..6f7288663e 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -149,13 +149,13 @@ def forward_loop(m): assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-3) # Validate decoder.layers.0.mlp activations - mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"] + mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]["activations"] assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-3) assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-3) assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-3) # Validate decoder.layers.1.mlp activations - mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"] + mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]["activations"] assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-3) assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-3) assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-3) @@ -168,7 +168,7 @@ def forward_loop(m): # Validate decoder.layers.0.self_attention activations assert "decoder.layers.0.self_attention" in rank_0_activations - attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"] + attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"]["activations"] assert attn_0_acts.shape == torch.Size([256]) assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-3) assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-3) @@ -176,7 +176,7 @@ def forward_loop(m): # Validate decoder.layers.1.self_attention activations assert "decoder.layers.1.self_attention" in rank_0_activations - attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"] + attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"]["activations"] assert attn_1_acts.shape == torch.Size([256]) assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-3) assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-3) From 74aae832dcfcb08ef2f806a6a68b159975fb0777 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 4 Dec 2025 21:41:53 +0100 Subject: [PATCH 22/62] Add MIP step to the compress algorithm (#627) ## What does this PR do? Add MIP step to the compress algorithm. --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Daniel Korzekwa Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/gpu_tests.yml | 2 + modelopt/torch/_compress/compress.py | 8 +- .../_compress/mip/constrain_search_space.py | 407 +++++++++ ...dy_search_with_multi_layer_replacements.py | 180 ++++ .../torch/_compress/mip/grouped_knapsack.py | 231 +++++ .../_compress/mip/mip_and_realize_models.py | 90 ++ .../mip/mip_with_multi_layer_replacements.py | 198 +++++ modelopt/torch/_compress/mip/run_puzzle.py | 839 ++++++++++++++++++ modelopt/torch/_compress/mip/utils.py | 75 ++ .../nas/plugins/compress_nas_plugin.py | 2 +- .../replacement_library/replacement_utils.py | 4 +- modelopt/torch/_compress/utils/utils.py | 46 + setup.py | 1 + .../torch/_compress/compress_test_utils.py | 2 +- .../torch/_compress/conftest.py | 0 ..._convert_llama3_config_to_decilm_config.py | 2 +- .../_compress/nas/plugins/test_nas_convert.py | 12 +- .../_compress/nas/plugins/test_nas_search.py | 8 +- .../configs/Llama-3_1-8B-attn-pruning.yaml | 0 .../configs/Llama-3_1-8B-ffn-pruning.yaml | 0 .../configs/pruning/attn_pruning.yaml | 0 .../configs/pruning/ffn_pruning.yaml | 0 .../configs/pruning/hidden_dim_pruning.yaml | 0 .../configs/pruning/pruning_defaults.yaml | 0 .../configs/validate_model_defaults.yaml | 0 .../configs/validate_solutions_defaults.yaml | 0 .../tokenizer/special_tokens_map.json | 0 .../resources/tokenizer/tokenizer.json | 0 .../resources/tokenizer/tokenizer_config.json | 0 .../resources/tokenizer/truncate_tokenizer.py | 0 .../torch/_compress/test_compress.py | 8 +- 31 files changed, 2088 insertions(+), 27 deletions(-) create mode 100644 modelopt/torch/_compress/mip/constrain_search_space.py create mode 100644 modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py create mode 100644 modelopt/torch/_compress/mip/grouped_knapsack.py create mode 100644 modelopt/torch/_compress/mip/mip_and_realize_models.py create mode 100644 modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py create mode 100644 modelopt/torch/_compress/mip/run_puzzle.py create mode 100644 modelopt/torch/_compress/mip/utils.py rename tests/{experimental => gpu}/torch/_compress/compress_test_utils.py (98%) rename tests/{experimental => gpu}/torch/_compress/conftest.py (100%) rename tests/{experimental => gpu}/torch/_compress/nas/plugins/test_nas_convert.py (92%) rename tests/{experimental => gpu}/torch/_compress/nas/plugins/test_nas_search.py (92%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/pruning/attn_pruning.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/pruning/ffn_pruning.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/pruning/pruning_defaults.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/validate_model_defaults.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/configs/validate_solutions_defaults.yaml (100%) rename tests/{experimental => gpu}/torch/_compress/resources/tokenizer/special_tokens_map.json (100%) rename tests/{experimental => gpu}/torch/_compress/resources/tokenizer/tokenizer.json (100%) rename tests/{experimental => gpu}/torch/_compress/resources/tokenizer/tokenizer_config.json (100%) rename tests/{experimental => gpu}/torch/_compress/resources/tokenizer/truncate_tokenizer.py (100%) rename tests/{experimental => gpu}/torch/_compress/test_compress.py (93%) diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index 2ffb738c48..2dba922d43 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -74,6 +74,8 @@ jobs: - name: Setup environment variables run: | echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" >> $GITHUB_ENV + - name: Install dependencies for mip + run: apt-get update && apt-get install -y libffi-dev - name: Run gpu tests run: pip install tox-current-env && tox -e py312-cuda12-gpu --current-env gpu-tests-non-pr: diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 765e3d6d42..8504631cbc 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -20,13 +20,13 @@ """ -import build_library_and_stats -import mip_and_realize_models -import pruning_ckpts -import scoring from omegaconf import DictConfig import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch._compress.build_library_and_stats as build_library_and_stats +import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models +import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch._compress.scoring.scoring as scoring from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch._compress.tools.runtime import IRuntime diff --git a/modelopt/torch/_compress/mip/constrain_search_space.py b/modelopt/torch/_compress/mip/constrain_search_space.py new file mode 100644 index 0000000000..e30ee24783 --- /dev/null +++ b/modelopt/torch/_compress/mip/constrain_search_space.py @@ -0,0 +1,407 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constrains the search space for the MIP optimization.""" + +import traceback + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) +from modelopt.torch._compress.utils.utils import load_json + + +def drop_attentions_only(gathered_metrics, teacher_intermediate_size, teacher_n_heads_in_group): + """ + changes the search space such that puzzle is not allowed to change the ffns + but is only allowed to drop or reduce attention. + + Usage example: + add the following flags to your run_puzzle command: + + --constrain_search_func drop_attentions_only --constrain_search_args {\"teacher_intermediate_size\": 14336, \"teacher_n_heads_in_group\": 16, \"above_layer\": 60} + + """ + + for block_name, block_variants in gathered_metrics.items(): + to_delete = [] # Collect keys to delete after the loop + for variant_config, variant_metrics in block_variants.items(): + block_intermediate_size = variant_config.ffn.intermediate_size + block_attn_n_heads = variant_config.attention.n_heads_in_group + if ( + ( + block_intermediate_size is not None + and block_intermediate_size != teacher_intermediate_size + ) + or variant_config.ffn.replace_with_linear + or variant_config.ffn.no_op ## uncomment this line if you want to drop only attns + or variant_config.attention.replace_with_linear + or ( + block_attn_n_heads is not None + and block_attn_n_heads != teacher_n_heads_in_group + ) + ): + print(f"Marking for deletion: {block_name}-{variant_config}") + to_delete.append(variant_config) + for key in to_delete: + del block_variants[key] + + print("new search space in block 0", gathered_metrics["block_0"]) + return gathered_metrics + + +def reduce_only_ffns( + gathered_metrics, + teacher_intermediate_size: int, + teacher_n_heads_in_group: int, + above_layer: int, + allow_no_ops: bool, +): + """ + only allows to reduce FFNs but not to completely drop them from layer 60 onwards + attention is only allowed to be like uniform teacher + + Usage example: + add the following flags to your run_puzzle command: + constrain_search_args='{"teacher_intermediate_size": 14336, "teacher_n_heads_in_group": 16, "above_layer": 60, "allow_no_ops": false}' + + sbatch puzzle/cli/run_puzzle ... --constrain_search_func reduce_only_ffns --constrain_search_args="$(echo "$constrain_search_args" | jq -c .)" + """ + print(f"{teacher_n_heads_in_group=}") + for block_name, block_variants in gathered_metrics.items(): + to_delete = [] # Collect keys to delete after the loop + block_id = int(block_name.split("_")[1]) + + for variant_config, variant_metrics in block_variants.items(): + block_intermediate_size = variant_config.ffn.intermediate_size + block_attn_n_heads = variant_config.attention.n_heads_in_group + + attn_no_op = variant_config.attention.no_op + attn_linear = variant_config.attention.replace_with_linear + if ( + attn_no_op + or attn_linear + or (block_attn_n_heads != teacher_n_heads_in_group) # keep attention as the teacher + or ( + block_id <= above_layer + and (block_intermediate_size != teacher_intermediate_size) + ) + or ((not allow_no_ops) and variant_config.ffn.no_op) + ): + # print(f"Marking for deletion: {block_name}-{variant_config}") + to_delete.append(variant_config) # Add key to delete list + + for key in to_delete: + del block_variants[key] + + print("new search space in block 0", gathered_metrics["block_0"]) + return gathered_metrics + + +def drop_entire_blocks_only(gathered_metrics): + teacher_block_config = _infer_teacher_config(gathered_metrics) + for block_name, block_variants in gathered_metrics.items(): + to_delete = [] # Collect keys to delete after the loop + for variant_config, variant_metrics in block_variants.items(): + is_no_op_block = ( + variant_config.ffn.no_op + and variant_config.attention.no_op + and getattr(variant_config, "parallel_blocks", None) is None + ) + is_teacher = variant_config == teacher_block_config + if not is_no_op_block and not is_teacher: + to_delete.append(variant_config) + for key in to_delete: + del block_variants[key] + + print("new search space in block 0", gathered_metrics["block_0"]) + return gathered_metrics + + +def css_to_reference_attention(gathered_metrics, attention_pruned_arch): + """ + given a reference architecture we fix the search space to only include options that change the FFNs + but to never change the Attentions from the reference arch's Attentions. + """ + + attention_pruned_arch = load_json(attention_pruned_arch)[0] + attention_dropped_blocks = [ + block_name + for block_name, block_config in attention_pruned_arch["chosen_items"].items() + if block_config["attention"]["no_op"] + ] + + for block_name, block_variants in gathered_metrics.items(): + to_delete = [] # Collect keys to delete after the loop + for variant_config, _ in block_variants.items(): + # Uncomment and adjust this block if needed + # does drop only attention + block_attn_n_heads = variant_config.attention.n_heads_in_group + + reference_arch_attn = attention_pruned_arch["chosen_items"][block_name]["attention"][ + "n_heads_in_group" + ] + if ( # we reduce the search space by keeping the reference arch attention as is + (block_name in attention_dropped_blocks and not variant_config.attention.no_op) + or ( + block_name not in attention_dropped_blocks + and block_attn_n_heads != reference_arch_attn + ) + ): + print(f"Marking for deletion: {block_name}-{variant_config}") + to_delete.append(variant_config) + + # Delete marked keys outside the loop + for key in to_delete: + del block_variants[key] + + print("new search space in block 0", gathered_metrics["block_0"]) + return gathered_metrics + + +def css_to_reference_ffn(gathered_metrics, ffn_pruned_arch, allow_linear_attn=True): + """ + given a reference architecture we fix the search space to only include options that change the Attentions + but to never change the FFNs from the reference arch's FFNs. + """ + + ffn_pruned_arch = load_json(ffn_pruned_arch)[0] + + for block_name, block_variants in gathered_metrics.items(): + to_delete = [] # Collect keys to delete after the loop + for variant_config, _ in block_variants.items(): + block_ffn = variant_config.ffn + is_linear_attn = variant_config.attention.replace_with_linear + + reference_arch_ffn = ffn_pruned_arch["chosen_items"][block_name]["ffn"] + reference_arch_ffn = FFNConfig(**reference_arch_ffn) + + if ( # we reduce the search space by keeping the reference arch ffn as is + (block_ffn != reference_arch_ffn) or (not allow_linear_attn and is_linear_attn) + ): + # print(f"Marking for deletion: {block_name}-{variant_config}") + to_delete.append(variant_config) + + # Delete marked keys outside the loop + for key in to_delete: + del block_variants[key] + + print("new search space in block 0", gathered_metrics["block_0"]) + return gathered_metrics + + +def avoid_variable_gqa( + gathered_metrics, + allow_no_op_attn: bool = True, + allow_linear_attn: bool = False, + target_n_heads_in_group: int = None, +): + """ + Allow only the teacher n_heads_in_group, + and optionally also attention no-op (default allow) + and attention linear (default avoid). + + This reducer affects only the attention layers: FFNs are allowed their entire search space. + """ + is_multi_layer_puzzle = is_replacement_gathered_metrics(gathered_metrics) + if is_multi_layer_puzzle: + teacher_block_config = infer_teacher_replacement_config(gathered_metrics) + else: + teacher_block_config = _infer_teacher_config(gathered_metrics) + + if target_n_heads_in_group is None: + target_n_heads_in_group = teacher_block_config.attention.n_heads_in_group + + if not is_multi_layer_puzzle: + for block_name, block_variants in gathered_metrics.items(): + to_delete = [] # Collect keys to delete after the loop + + for variant_config, variant_metrics in block_variants.items(): + if not ( + (variant_config.attention.n_heads_in_group == target_n_heads_in_group) + or (variant_config.attention.no_op and allow_no_op_attn) + or (variant_config.attention.replace_with_linear and allow_linear_attn) + ): + to_delete.append(variant_config) + + for key in to_delete: + del block_variants[key] + else: + to_delete = [] # Collect keys to delete after the loop + for replacement_id, replacement in gathered_metrics.items(): + variant_config = replacement["block_config"] + if not ( + (variant_config.attention.n_heads_in_group == target_n_heads_in_group) + or (variant_config.attention.no_op and allow_no_op_attn) + or (variant_config.attention.replace_with_linear and allow_linear_attn) + ): + to_delete.append(replacement_id) + + for key in to_delete: + del gathered_metrics[key] + if not is_multi_layer_puzzle: + print("new search space in block 0", gathered_metrics["block_0"]) + else: + parent_layer_idx = 0 + print( + "new search space in block {parent_layer_idx}", + [ + replacement["block_config"] + for replacement_id, replacement in gathered_metrics.items() + if replacement["parent_layer_indices"][0] == parent_layer_idx + ], + ) + return gathered_metrics + + +def reduce_in_range( + gathered_metrics, + layer_start: int, + layer_end: int, +): + """ + Allow only reduction of layers between layer_start and layer_end. Leyers before layers start, and after layer_end are kept as is (the teacher). + + """ + assert layer_start < layer_end, ( + f"Wrong input arguments: {layer_start=} must be less than {layer_end=}" + ) + is_multi_layer_puzzle = is_replacement_gathered_metrics(gathered_metrics) + if is_multi_layer_puzzle: + teacher_block_config = infer_teacher_replacement_config(gathered_metrics) + else: + teacher_block_config = _infer_teacher_config(gathered_metrics) + + to_delete = [] # Collect keys to delete after the loop + for replacement_id, replacement in gathered_metrics.items(): + block_id = max(replacement["parent_layer_indices"]) + variant_config = replacement["block_config"] + is_teacher = variant_config == teacher_block_config + if (block_id < layer_start or block_id > layer_end) and not is_teacher: + to_delete.append(replacement_id) + + for key in to_delete: + del gathered_metrics[key] + + if not is_multi_layer_puzzle: + print("new search space in block 0", gathered_metrics["block_0"]) + else: + parent_layer_idx = 0 + print( + "new search space in block {parent_layer_idx}", + [ + replacement["block_config"] + for replacement_id, replacement in gathered_metrics.items() + if replacement["parent_layer_indices"][0] == parent_layer_idx + ], + ) + return gathered_metrics + + +############################################################################################# + + +# automatically builds a dictionary mapping method names in this module to their functions +# this dictionary is used to dynamically dispatch functions +dispatcher = { + method_name: method_callable + for method_name, method_callable in globals().items() + if callable(method_callable) +} + + +def is_replacement_gathered_metrics(gathered_metrics) -> bool: + # if the gathered metrics is a replacement, then it is a dictionary of the form {'replacement_{id}': replacement_metrics} + + return isinstance(gathered_metrics, dict) and all( + key.startswith("replacement_") for key in gathered_metrics + ) + + +def _infer_teacher_config(gathered_metrics) -> BlockConfig: + n_heads_in_group, intermediate_size = zip( + *[ + (variant_config.attention.n_heads_in_group, variant_config.ffn.intermediate_size) + for block_name, block_variants in gathered_metrics.items() + for variant_config, variant_metrics in block_variants.items() + ] + ) + teacher_n_heads_in_group = min(filter(None, n_heads_in_group)) + teacher_intermediate_size = max(filter(None, intermediate_size)) + + unique_teacher_candidates = set() + for block_name, block_variants in gathered_metrics.items(): + for variant_config, variant_metrics in block_variants.items(): + if ( + variant_config.ffn.intermediate_size == teacher_intermediate_size + and variant_config.attention.n_heads_in_group == teacher_n_heads_in_group + ): + unique_teacher_candidates.add(variant_config) + + assert len(unique_teacher_candidates) == 1, ( + f"Woops, expected example one candidate to be the teacher block config, instead found: {unique_teacher_candidates=}" + ) + + teacher_block_config = unique_teacher_candidates.pop() + return teacher_block_config + + +def infer_teacher_replacement_config(gathered_metrics) -> BlockConfig: + n_heads_in_group, intermediate_size = zip( + *[ + ( + replacement["block_config"].attention.n_heads_in_group, + replacement["block_config"].ffn.intermediate_size, + ) + for replacement_id, replacement in gathered_metrics.items() + ] + ) + teacher_intermediate_size = max(filter(None, intermediate_size)) + teacher_n_heads_in_group = min(filter(None, n_heads_in_group)) + unique_teacher_candidates = set() + for replacement_id, replacement in gathered_metrics.items(): + if ( + replacement["block_config"].ffn.intermediate_size == teacher_intermediate_size + and replacement["block_config"].attention.n_heads_in_group == teacher_n_heads_in_group + ): + unique_teacher_candidates.add(replacement["block_config"]) + + assert len(unique_teacher_candidates) == 1, ( + f"Woops, expected example one candidate to be the teacher block config, instead found: {unique_teacher_candidates=}" + ) + + teacher_replacement_config = unique_teacher_candidates.pop() + return teacher_replacement_config + + +def apply(css_func_name, gathered_metrics, method_kwargs): + search_space_reducer = dispatcher.get(css_func_name) + if search_space_reducer is None: + raise ValueError( + f"could not find a function called `{css_func_name}` in {__name__}.py to reduce search space " + ) + + try: + gathered_metrics = search_space_reducer(gathered_metrics, **method_kwargs) + except Exception as e: + traceback.print_exc() + raise ValueError( + f"something went wrong when trying to apply the following search space reducer `{css_func_name}` \ + with the folloing args: {method_kwargs}, here's the exception: {e}" + ) + + return gathered_metrics diff --git a/modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py b/modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py new file mode 100644 index 0000000000..719643cc22 --- /dev/null +++ b/modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performs greedy search to find optimal multi-layer replacements under resource constraints.""" + +# mypy: ignore-errors +import math +from copy import deepcopy +from random import random +from typing import Any, Hashable, TypeAlias + +from .utils import InfeasibleError, consecutive_ngrams, get_nested_key, sort_replacements + +ReplacementID: TypeAlias = Hashable +Replacement: TypeAlias = dict[str, Any] +ChosenReplacements: TypeAlias = list[Replacement] + + +def run_greedy_search( + teacher_replacements: list[Replacement], + student_replacements: list[Replacement], + objective: str, + constraints: dict[str, float], + bigger_is_better: bool, +) -> tuple[ChosenReplacements, float, dict[str, float]]: + print("####### running greedy search #######") + teacher_replacements = deepcopy(teacher_replacements) + student_replacements = deepcopy(student_replacements) + chosen_replacements: ChosenReplacements = [] + + teacher_replacements = { + replacement["parent_layer_indices"][0]: replacement for replacement in teacher_replacements + } + + all_parent_layers = set(teacher_replacements.keys()) + uncovered_parent_layers = set(all_parent_layers) + + while True: + if len(student_replacements) == 0: + raise InfeasibleError() + + choice_func = max if bigger_is_better else min + best_replacement = choice_func( + student_replacements, key=lambda replacement: get_nested_key(replacement, objective) + ) + chosen_replacements.append(best_replacement) + uncovered_parent_layers -= set(best_replacement["parent_layer_indices"]) + student_replacements = _filter_overlapping_replacements( + student_replacements, uncovered_parent_layers + ) + + padded_chosen_replacements = list(chosen_replacements) + for uncovered_block_idx in uncovered_parent_layers: + padded_chosen_replacements.append(teacher_replacements[uncovered_block_idx]) + + all_constraints_satisfied = True + for constraint_key, max_cost in constraints.items(): + total_cost = sum( + get_nested_key(replacement, constraint_key) + for replacement in padded_chosen_replacements + ) + is_constraint_satisfied = total_cost < max_cost or math.isclose( + total_cost, max_cost, rel_tol=1e-9 + ) + if not is_constraint_satisfied: + all_constraints_satisfied = False + + if all_constraints_satisfied: + chosen_replacements = padded_chosen_replacements + break + + # Trust But Verify: calculate total value and costs, and check that all the constraints are filled + total_value = 0.0 + total_costs = {constraint_key: 0 for constraint_key in constraints.keys()} + chosen_layers = set() + for replacement in chosen_replacements: + total_value += get_nested_key(replacement, objective) + for constraint_key in constraints.keys(): + total_costs[constraint_key] += get_nested_key(replacement, constraint_key) + for parent_layer_idx in replacement["parent_layer_indices"]: + assert parent_layer_idx not in chosen_layers, ( + f"Found duplicate chosen layer {parent_layer_idx}" + ) + chosen_layers.add(parent_layer_idx) + + missing_layers = all_parent_layers - set(chosen_layers) + assert len(missing_layers) == 0, ( + f"The following layers were not chosen by any replacement:\n{missing_layers=}\n{chosen_replacements}" + ) + + for constraint_key, max_cost in constraints.items(): + assert total_costs[constraint_key] < max_cost or math.isclose( + total_costs[constraint_key], max_cost, rel_tol=1e-9 + ), ( + f"this constraint was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} <= {max_cost=}" + ) + + chosen_replacements = sort_replacements(chosen_replacements) + for cr in chosen_replacements: + if "block_config" in cr: + cr["child_block_configs"] = cr["block_config"] + + return [ + { + "chosen_replacements": chosen_replacements, + "total_value": total_value, + "total_costs": total_costs, + } + ] + + +def _filter_overlapping_replacements( + replacements: list[Replacement], + uncovered_parent_layers: set[int], +) -> list[Replacement]: + return [ + replacement + for replacement in replacements + if set(replacement["parent_layer_indices"]).issubset(uncovered_parent_layers) + ] + + +def usage_example(): + num_layers = 32 + num_options_per_parent_replacement = 5 + + teacher_replacements = [] + student_replacements = [] + for num_layers_in_replacement in (1, 2, 3): + for i_option in range(num_options_per_parent_replacement): + for parent_layer_indices in consecutive_ngrams(num_layers, num_layers_in_replacement): + is_teacher = num_layers_in_replacement == 1 and i_option == 0 + replacement_id = f"parent layers {parent_layer_indices} child config {i_option}" + replacement = { + "parent_layer_indices": parent_layer_indices, + "metrics": {"loss": random() if not is_teacher else 0.0}, + "stats": {"cost": 1}, + "replacement_id": replacement_id, + } + if is_teacher: + teacher_replacements.append(replacement) + else: + student_replacements.append(replacement) + + constraints = {"stats.cost": num_layers - 8} + (result,) = run_greedy_search( + teacher_replacements, + student_replacements, + objective="metrics.loss", + constraints=constraints, + bigger_is_better=False, + ) + chosen_replacements = result["chosen_replacements"] + total_value = result["total_value"] + total_costs = result["total_costs"] + + print() + print() + print(f"{total_value=}") + print(f"{total_costs=}") + print(f"{constraints=}") + print("chosen_replacements=") + print(chosen_replacements) + print("\n".join([rep["replacement_id"] for rep in chosen_replacements])) + + +if __name__ == "__main__": + usage_example() diff --git a/modelopt/torch/_compress/mip/grouped_knapsack.py b/modelopt/torch/_compress/mip/grouped_knapsack.py new file mode 100644 index 0000000000..5769ded3cd --- /dev/null +++ b/modelopt/torch/_compress/mip/grouped_knapsack.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Solves the grouped knapsack problem using Mixed Integer Programming to find optimal item selections.""" + +# mypy: ignore-errors +import math +import warnings +from copy import deepcopy +from random import random +from typing import Any, Hashable, Iterable, Optional, TypeAlias, Union + +from mip import BINARY, Model, maximize, minimize, xsum +from tqdm import tqdm + +from .utils import InfeasibleError, get_nested_key + +Item: TypeAlias = dict[str, float | dict[str, float]] +Group: TypeAlias = dict[Hashable, Item] +ChosenItems: TypeAlias = dict[Hashable, Hashable] + + +def multi_solution_grouped_knapsack( + groups: dict[Hashable, Group], + objective: str, + constraints: dict[str, float], + bigger_is_better: bool, + num_solutions: int, + minimal_diversity: int = 1, + max_seconds_per_solution: Optional[float] = None, +) -> list[dict[str, Union[ChosenItems, float]]]: + solutions = [] + previous_choices = [] + for i_run in tqdm(range(num_solutions), desc="multi_solution_grouped_knapsack"): + try: + chosen_items, total_value, total_costs = grouped_knapsack( + groups, + objective, + constraints, + bigger_is_better, + previous_choices, + minimal_diversity, + max_seconds_per_solution, + ) + except InfeasibleError: + warnings.warn(f"Found only {i_run} feasible solutions (requested {num_solutions})") + break + previous_choices.append(chosen_items) + solutions.append( + {"chosen_items": chosen_items, "total_value": total_value, "total_costs": total_costs} + ) + return solutions + + +def grouped_knapsack( + groups: dict[Hashable, Group], + objective: str, + constraints: dict[str, float | tuple[float, float]], + bigger_is_better: bool, + previous_choices: Optional[list[ChosenItems]] = None, + minimal_diversity: int = 1, + max_seconds_per_solution: Optional[float] = None, +) -> tuple[ChosenItems, float, dict[str, float]]: + groups = deepcopy(groups) + mip_model = Model() + + objective_vars = [] + constraint_vars = {constraint_key: [] for constraint_key in constraints.keys()} + for group_name, group_items in groups.items(): + group_vars = [] + for item_name, item in group_items.items(): + is_chosen = mip_model.add_var(var_type=BINARY) + item["is_chosen"] = is_chosen + group_vars.append(is_chosen) + objective_vars.append(is_chosen * get_nested_objective(item, objective)) + for constraint_key in constraints.keys(): + constraint_vars[constraint_key].append( + is_chosen * get_nested_key(item, constraint_key) + ) + + mip_model += xsum(group_vars) == 1 + + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) <= max_cost + if min_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) >= min_cost + + if previous_choices is not None: + for previous_chosen_items in previous_choices: + corresponding_vars = [ + groups[group_name][item_name]["is_chosen"] + for group_name, item_name in previous_chosen_items.items() + ] + mip_model += xsum(corresponding_vars) <= len(groups) - minimal_diversity + + mip_model.objective = ( + maximize(xsum(objective_vars)) if bigger_is_better else minimize(xsum(objective_vars)) + ) + + if max_seconds_per_solution is not None: + mip_model.max_seconds = max_seconds_per_solution + + mip_model.optimize() + + if is_chosen.x is None: + raise InfeasibleError() + + total_value = 0.0 + total_costs = {constraint_key: 0 for constraint_key in constraints.keys()} + chosen_items: ChosenItems = dict() + for group_name, group_items in groups.items(): + for item_name, item in group_items.items(): + is_chosen = item["is_chosen"].x >= 0.99 + if is_chosen: + assert group_name not in chosen_items + chosen_items[group_name] = item_name + total_value += get_nested_objective(item, objective) + for constraint_key in constraints.keys(): + total_costs[constraint_key] += get_nested_key(item, constraint_key) + + if len(chosen_items) != len(groups): + in_groups_and_not_in_chosen_items = set(groups.keys()) - set(chosen_items.keys()) + in_chosen_items_and_not_in_groups = set(chosen_items.keys()) - set(groups.keys()) + missing_groups = [groups[key] for key in in_groups_and_not_in_chosen_items] + raise RuntimeError(f""" + Different number of 'chosen_items' and 'groups': {len(chosen_items)=} {len(groups)=} + {in_groups_and_not_in_chosen_items=} + {in_chosen_items_and_not_in_groups=} + {missing_groups=} + """) + + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + assert total_costs[constraint_key] < max_cost or math.isclose( + total_costs[constraint_key], max_cost, rel_tol=1e-9 + ), ( + f"This max_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} > {max_cost=}" + ) + if min_cost is not None: + assert total_costs[constraint_key] > min_cost or math.isclose( + total_costs[constraint_key], min_cost, rel_tol=1e-9 + ), ( + f"This min_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} < {min_cost=}" + ) + + for previous_chosen_items in previous_choices: + num_differences = 0 + for group_name in groups.keys(): + num_differences += previous_chosen_items[group_name] != chosen_items[group_name] + assert num_differences >= minimal_diversity + + return chosen_items, total_value, total_costs + + +def get_nested_objective(dictionary: dict[str, Any], nested_key: str) -> Any: + if nested_key.startswith("metrics."): + # handle metrics that have '.' in their name + metric = nested_key.split("metrics.")[1] + return dictionary["metrics"][metric] + else: + return get_nested_key(dictionary, nested_key) + + +def usage_example(): + num_layers = 32 + num_configs_per_block = 100 + groups = { + f"layer_{i_layer}": { + f"config_{i_config}": { + "metrics": {"accuracy": random()}, + "stats": {"memory_mib": random() * 100, "runtime_ms": random() * 10}, + } + for i_config in range(num_configs_per_block) + } + for i_layer in range(num_layers) + } + + minimal_diversity = 10 + constraints = {"stats.memory_mib": num_layers * 50.0, "stats.runtime_ms": num_layers * 5.0} + solutions = multi_solution_grouped_knapsack( + groups, + objective="metrics.accuracy", + constraints=constraints, + bigger_is_better=True, + num_solutions=10, + minimal_diversity=minimal_diversity, + ) + + print() + print(constraints) + + for i_run, solution in enumerate(solutions): + print() + print(f"run {i_run}") + print(solution) + + print(f"Checking differences, should be at least {minimal_diversity}:") + for a in range(len(solutions)): + for b in range(a + 1, len(solutions)): + num_differences = 0 + for group_name in groups.keys(): + num_differences += ( + solutions[a]["chosen_items"][group_name] + != solutions[b]["chosen_items"][group_name] + ) + print(a, "<>", b, "=", num_differences) + + +if __name__ == "__main__": + usage_example() diff --git a/modelopt/torch/_compress/mip/mip_and_realize_models.py b/modelopt/torch/_compress/mip/mip_and_realize_models.py new file mode 100644 index 0000000000..83d8b23f56 --- /dev/null +++ b/modelopt/torch/_compress/mip/mip_and_realize_models.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runs MIP (Mixed Integer Programming) optimization and realizes the resulting model solutions.""" + +# mypy: ignore-errors +from pathlib import Path +from typing import List + +import hydra +import torch +import torch.distributed as dist +from omegaconf import DictConfig + +from modelopt.torch._compress.mip.run_puzzle import run_puzzle +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime +from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( + validate_puzzle_solutions, +) +from modelopt.torch._compress.utils.dist_utils import is_distributed + + +def launch_mip(cfg: DictConfig) -> List[str]: + solution_paths = run_puzzle(args=cfg.mip) + return solution_paths + + +def launch_realize_model(cfg: DictConfig, runtime: IRuntime): + validate_puzzle_solutions(args=cfg.realize_model, runtime=runtime) + + +def launch_mip_and_realize_model(cfg: DictConfig, runtime: IRuntime): + if runtime.is_main_process: + solution_paths = launch_mip(cfg) + length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long) + else: + solution_paths = None + length_tensor = torch.tensor([0], dtype=torch.long) + + if not cfg.skip_realize_model: + if runtime.world_size > 1: + dist.broadcast(length_tensor, src=0) + + list_length = length_tensor.item() + + if runtime.global_rank != 0: + solution_paths = [None] * list_length + + if runtime.world_size > 1: + dist.broadcast_object_list(solution_paths, src=0) + + for solution_path in solution_paths: + mprint(f"Realize model for the solution: {solution_path}") + cfg.realize_model.solutions_path = Path(solution_path) + launch_realize_model(cfg, runtime=runtime) + runtime.wait_for_everyone() + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + + _runtime = ( + NativeDDP_Runtime( + dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") + ) + if is_distributed() + else BaseRuntime(dtype=torch.bfloat16) + ) + with _runtime as runtime: + launch_mip_and_realize_model(cfg, runtime) + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py b/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py new file mode 100644 index 0000000000..50525c846c --- /dev/null +++ b/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Solves multi-layer replacement optimization using Mixed Integer Programming.""" + +# mypy: ignore-errors +import math +import warnings +from collections import defaultdict +from copy import deepcopy +from random import random +from typing import Any, Hashable, Iterable, Optional, TypeAlias + +from mip import BINARY, Model, maximize, minimize, xsum + +from .utils import InfeasibleError, consecutive_ngrams, get_nested_key, sort_replacements + +ReplacementID: TypeAlias = Hashable +Replacement: TypeAlias = dict[str, Any] +ChosenReplacements: TypeAlias = list[Replacement] + + +def run_mip( + replacements: dict[ReplacementID, Replacement], + objective: str, + constraints: dict[str, float], + bigger_is_better: bool, + max_seconds_per_solution: Optional[float] = None, +) -> tuple[ChosenReplacements, float, dict[str, float]]: + orig_num_replacements = len(replacements) + replacements = { + replacement_id: deepcopy(replacement) + for replacement_id, replacement in replacements.items() + if math.isfinite(get_nested_key(replacement, objective)) + } + if len(replacements) < orig_num_replacements: + print("\n\n\n") + warnings.warn( + f"mip: removed {orig_num_replacements - len(replacements)} replacements with NaN/inf objective value" + ) + print("\n\n\n") + + mip_model = Model() + + objective_vars = [] + constraint_vars = {constraint_key: [] for constraint_key in constraints.keys()} + choice_indicators_by_layer = defaultdict(list) + for replacement_id, replacement in replacements.items(): + is_chosen = mip_model.add_var(var_type=BINARY) + replacement["is_chosen"] = is_chosen + + for parent_layer_idx in replacement["parent_layer_indices"]: + choice_indicators_by_layer[parent_layer_idx].append(is_chosen) + + objective_vars.append(is_chosen * get_nested_key(replacement, objective)) + + for constraint_key in constraints.keys(): + constraint_vars[constraint_key].append( + is_chosen * get_nested_key(replacement, constraint_key) + ) + + # MIP constraints: each parent layer must come from exactly one chosen replacement + for parent_layer_idx, curr_choice_indicators in choice_indicators_by_layer.items(): + mip_model += xsum(curr_choice_indicators) == 1 + + # MIP constraints: the sum of chosen replacement costs must be lower than the max cost + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) <= max_cost + if min_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) >= min_cost + + # MIP objective + mip_model.objective = ( + maximize(xsum(objective_vars)) if bigger_is_better else minimize(xsum(objective_vars)) + ) + + if max_seconds_per_solution is not None: + mip_model.max_seconds = max_seconds_per_solution + + mip_model.optimize() + + if is_chosen.x is None: + return [] + # raise InfeasibleError() + + # Trust But Verify: calculate total value and costs, and check that all the constraints are filled + total_value = 0.0 + total_costs = {constraint_key: 0 for constraint_key in constraints.keys()} + chosen_replacements: ChosenReplacements = [] + chosen_layers = [] + for replacement_id, replacement in replacements.items(): + is_chosen = replacement["is_chosen"].x >= 0.99 + if is_chosen: + assert replacement not in chosen_replacements + chosen_replacements.append(replacement) + total_value += get_nested_key(replacement, objective) + for constraint_key in constraints.keys(): + total_costs[constraint_key] += get_nested_key(replacement, constraint_key) + for parent_layer_idx in replacement["parent_layer_indices"]: + assert parent_layer_idx not in chosen_layers + chosen_layers.append(parent_layer_idx) + + missing_layers = set(choice_indicators_by_layer.keys()) - set(chosen_layers) + assert len(missing_layers) == 0, ( + f"The following layers were not chosen by any replacement:\n{missing_layers=}\n{chosen_replacements}" + ) + + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + assert total_costs[constraint_key] < max_cost or math.isclose( + total_costs[constraint_key], max_cost, rel_tol=1e-9 + ), ( + f"This max_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} > {max_cost=}" + ) + if min_cost is not None: + assert total_costs[constraint_key] > min_cost or math.isclose( + total_costs[constraint_key], min_cost, rel_tol=1e-9 + ), ( + f"This min_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} < {min_cost=}" + ) + + chosen_replacements = sort_replacements(chosen_replacements) + for cr in chosen_replacements: + del cr["is_chosen"] # not copyable, will cause errors in deep copy + if "block_config" in cr: + cr["child_block_configs"] = cr["block_config"] + # del cr['block_config'] for now the dump includes both keys (duplicated values) # we might wanna either delete one of them or keep both + # I prefer keeping block_config and deleting 'child_block_configs' from previous puzzle steps + + return [ + { + "chosen_replacements": chosen_replacements, + "total_value": total_value, + "total_costs": total_costs, + } + ] + + +def usage_example(): + num_layers = 32 + num_options_per_parent_replacement = 5 + + replacements = dict() + for num_layers_in_replacement in (1, 2, 3): + for i_option in range(num_options_per_parent_replacement): + for parent_layer_indices in consecutive_ngrams(num_layers, num_layers_in_replacement): + replacement_id = f"parent layers {parent_layer_indices} child config {i_option}" + replacement = { + "parent_layer_indices": parent_layer_indices, + "metrics": {"loss": random()}, + "stats": {"memory_mib": random() * 100, "runtime_ms": random() * 10}, + "replacement_id": replacement_id, + } + replacements[replacement_id] = replacement + + constraints = {"stats.memory_mib": num_layers * 15.0, "stats.runtime_ms": num_layers * 1.5} + (result,) = run_mip( + replacements, + objective="metrics.loss", + constraints=constraints, + bigger_is_better=False, + ) + chosen_replacements = result["chosen_replacements"] + total_value = result["total_value"] + total_costs = result["total_costs"] + + print() + print() + print(f"{total_value=}") + print(f"{total_costs=}") + print(f"{constraints=}") + print("chosen_replacements=") + print("\n".join([rep["replacement_id"] for rep in chosen_replacements])) + + +if __name__ == "__main__": + usage_example() diff --git a/modelopt/torch/_compress/mip/run_puzzle.py b/modelopt/torch/_compress/mip/run_puzzle.py new file mode 100644 index 0000000000..fd883e969f --- /dev/null +++ b/modelopt/torch/_compress/mip/run_puzzle.py @@ -0,0 +1,839 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main entry point for running the puzzle optimization to find optimal layer configurations.""" + +# mypy: ignore-errors +import argparse +import dataclasses +import enum +import json +from copy import deepcopy +from pathlib import Path +from typing import Any, Hashable, Iterable, List, Literal, TypeAlias + +import numpy as np +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf + +import modelopt.torch._compress.mip.constrain_search_space as css +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) +from modelopt.torch._compress.mip.greedy_search_with_multi_layer_replacements import ( + run_greedy_search, +) +from modelopt.torch._compress.mip.mip_with_multi_layer_replacements import ( + run_mip as run_multi_layer_replacement_mip, +) +from modelopt.torch._compress.replacement_library.replacement_utils import ( + extract_block_configs_and_locations, + parse_layer_replacement, + replacement_is_teacher, +) +from modelopt.torch._compress.tools.checkpoint_utils import load_model_config +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch._compress.utils.parsing import get_nested_key, parse_json, parse_path +from modelopt.torch._compress.utils.utils import block_config_to_str, solution_to_str + +""" +Usage: +Must specify either --single_block_replacement_validation_dir and --subblock_stats_path (in which case the metrics will +be gathered from the validation output files) or --gathered_metrics_path (in which case the metrics will be read from +this json file). + +Constraints can be specified either as 'mip_constraints' (the actual constraints that go into the MIP, e.g. 'stats.memory_mib', 'stats.runtime_ms'), +or as "human constraints" (e.g. 'target_memory', 'target_throughput', for the full list see PuzzleConstraints._ALLOWED_HUMAN_CONSTRAINTS). + +""" + +PuzzleMetrics: TypeAlias = dict[Hashable, dict[Hashable, dict[str, float]]] +MultiLayerPuzzleMetrics: TypeAlias = dict[str, dict[str, Hashable]] + + +@dataclasses.dataclass +class PuzzleConstraints: + """A set of puzzle constraints can be expressed either directly as the mip constraints (e.g. 'stats.memory_mib') or as human constraints (e.g. 'target_throughput')""" + + class Type(enum.Enum): + MIP = "mip" + HUMAN = "human" + + _ALLOWED_HUMAN_CONSTRAINTS = { + "target_memory", + "target_throughput", + "target_latency", + "target_time_to_first_token", + "num_params", + "stats.has_attention", + } + type: Type + name: str = dataclasses.field(init=False) + constraints: dict[str, Any] + + @staticmethod + def sizeof_fmt(num, suffix=""): + for unit in ("", "K", "M", "G", "T"): + if abs(num) < 1000.0: + return f"{num:g}{unit}{suffix}" + num /= 1000.0 + return f"{num:.1f}P{suffix}" + + def _validate_human_constraints(self): + illegal_constraints = set(self.constraints.keys()) - self._ALLOWED_HUMAN_CONSTRAINTS + if illegal_constraints: + raise ValueError( + f"The following human_constraints are illegal: {','.join(illegal_constraints)}" + ) + + def format_num_params_to_float(self, num_params): + if isinstance(num_params, list): + return [self.format_num_params_to_float(x) for x in num_params] + if isinstance(num_params, str): + # we only deal with Billions of params scale + return float(num_params.replace("B", "")) * 1e9 + return num_params + + def format_num_params_to_str(self, num_params): + if isinstance(num_params, list): + return [self.format_num_params_to_str(x) for x in num_params] + if isinstance(num_params, float) or isinstance(num_params, int): + return f"{num_params / 1e9}B" + return num_params + + def __post_init__(self): + if self.type == PuzzleConstraints.Type.HUMAN: + self._validate_human_constraints() + + if "stats.active_params" in self.constraints: + self.constraints["stats.active_params"] = self.format_num_params_to_float( + self.constraints["stats.active_params"] + ) + + # Set self.name + constraints = deepcopy(self.constraints) # going to override with "human readable" versions + if "stats.active_params" in constraints: + constraints["stats.active_params"] = self.format_num_params_to_str( + constraints["stats.active_params"] + ) + + if self.type == PuzzleConstraints.Type.HUMAN: + # change values to a more human string form + if "target_memory" in constraints: + constraints["target_memory"] = str(constraints["target_memory"]) + "MiB" + if "num_params" in constraints: + constraints["num_params"] = self.sizeof_fmt(constraints["num_params"]) + + def build_constraint_name(constraint_name, constraint_value): + if isinstance(constraint_value, Iterable) and not isinstance(constraint_value, str): + return "-".join(f"{constraint_name}_{x}" for x in constraint_value) + else: + return f"{constraint_name}_{constraint_value}" + + self.name = "-".join(build_constraint_name(k, v) for k, v in constraints.items()).replace( + ".", "_" + ) + + def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]: + if self.type == PuzzleConstraints.Type.MIP: + return self.constraints + + assert all(key in subblock_stats_args for key in ("batch_size", "generation_seq_len")), ( + "Can't realize human constraints without 'block_size' and 'generation_seq_len' in subblock_stats_args." + ) + batch_size = subblock_stats_args["batch_size"] + generation_seq_len = subblock_stats_args["generation_seq_len"] + + mip_constraints = {} + + # Memory constraints + if "target_memory" in self.constraints: + mip_constraints["stats.memory_mib"] = self.constraints["target_memory"] + + # Throughput constraints + throughput_constraints = [] + if "target_throughput" in self.constraints: + throughput_constraints.append( + batch_size * generation_seq_len / self.constraints["target_throughput"] + ) + if "target_latency" in self.constraints: + throughput_constraints.append(self.constraints["target_latency"]) + if throughput_constraints: + mip_constraints["stats.runtime_ms"] = 1000 * min(throughput_constraints) + + # Prefill runtime constraint + if "target_time_to_first_token" in self.constraints: + mip_constraints["stats.prefill_runtime_ms"] = ( + 1000 * self.constraints["target_time_to_first_token"] + ) + + # Num params + if "num_params" in self.constraints: + mip_constraints["stats.num_params"] = self.constraints["num_params"] + if "stats.has_attention" in self.constraints: + mip_constraints["stats.has_attention"] = self.constraints["stats.has_attention"] + return mip_constraints + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + + parser.add_argument("--puzzle_profile", type=parse_path) + + parser.add_argument("--single_block_replacement_validation_dir", type=parse_path, default=None) + parser.add_argument( + "--gathered_metrics_path", + type=parse_path, + default=None, + help="Can be given explicitly instead of --single_block_replacement_validation_dir", + ) + + parser.add_argument("--subblock_stats_path", type=parse_path) + parser.add_argument("--subblock_stats_args", type=parse_json) + + parser.add_argument("--objective", type=str) + parser.add_argument("--mip_constraints", type=parse_json) + parser.add_argument("--human_constraints", type=parse_json) + parser.add_argument("--report_additional_costs", type=str, action="append", default=[]) + + parser.add_argument("--num_solutions", type=int) + parser.add_argument("--minimal_diversity", type=int) + parser.add_argument( + "--output_path", + type=parse_path, + help="The main folder under which all results will be stored.", + ) + + parser.add_argument("--max_seconds_per_solution", type=float, default=60.0) + parser.add_argument("--metric_overrides", type=parse_json, default=None) + parser.add_argument( + "--bigger_is_better", + action="store_true", + help="Set this if using accuracy objective, don't set if using loss objective", + ) + + parser.add_argument("--constrain_search_func", type=str, default=None) + parser.add_argument("--constrain_search_args", type=parse_json, default=dict()) + + parser.add_argument( + "--is_multi_layer_puzzle", + action="store_true", + default=True, + help="[DEPRECATED] This flag is now always True. Kept for backward compatibility.", + ) + parser.add_argument( + "--use_greedy_search", + action="store_true", + help="Use greedy search instead of mip. Only supported for multi-layer puzzle.", + ) + + args = parser.parse_args() + return args + + +def run_single_puzzle_config( + args: argparse.Namespace, + gathered_metrics: dict, + subblock_stats: dict, + subblock_stats_args: dict, + constraints: PuzzleConstraints, + output_folder, +) -> None: + from modelopt.torch._compress.mip.grouped_knapsack import multi_solution_grouped_knapsack + + args = deepcopy( + args + ) # we override the constraints and subblock_stats_args for this run to keep reporting out the same way. + + subblock_stats = filter_subblock_stats_by_args(subblock_stats, subblock_stats_args) + _add_block_stats_to_gathered_metrics(gathered_metrics, subblock_stats) + + output_folder.mkdir(parents=True, exist_ok=True) + _dump_gathered_metrics(gathered_metrics, output_folder, args.is_multi_layer_puzzle) + + non_block_stats = {"stats": _get_block_stats(subblock_stats, "non_block")} + batch_size = subblock_stats["args"]["batch_size"] + generation_seq_len = subblock_stats["args"]["generation_seq_len"] + + mip_constraints = constraints.to_mip_constraints(subblock_stats["args"]) + orig_mip_constraints = deepcopy(mip_constraints) + mprint(f"Solving for the following MIP constraints: {mip_constraints}") + args.mip_constraints = orig_mip_constraints + args.human_constraints = ( + constraints.constraints if constraints.type == PuzzleConstraints.Type.HUMAN else None + ) + args.subblock_stats_args = subblock_stats_args + + for stat_name, max_cost in mip_constraints.items(): + try: + non_block_cost = get_nested_key(non_block_stats, stat_name) + except KeyError: + non_block_cost = 0 + + is_min_max = isinstance(max_cost, Iterable) + min_cost = None + if is_min_max: + min_cost, max_cost = max_cost + + min_cost = min_cost - non_block_cost if (min_cost is not None) else None + max_cost = max_cost - non_block_cost if (max_cost is not None) else None + + if is_min_max: + mip_constraints[stat_name] = (min_cost, max_cost) + else: + mip_constraints[stat_name] = max_cost + + # If there's an additional cost that is not a constraint - set it to "inf" so MIP report the actual value of it. + for cost in set(args.report_additional_costs) - set(orig_mip_constraints.keys()): + mip_constraints[cost] = np.inf + + mprint(f"After non-block adjustments: {mip_constraints=}") + + if args.is_multi_layer_puzzle: + if not args.use_greedy_search: + solutions = run_multi_layer_replacement_mip( + replacements=gathered_metrics, + objective=args.objective, + constraints=mip_constraints, + bigger_is_better=args.bigger_is_better, + max_seconds_per_solution=args.max_seconds_per_solution, + ) + else: + teacher_replacements, student_replacements = [], [] + for replacement in gathered_metrics.values(): + if replacement["is_teacher"]: + teacher_replacements.append(replacement) + else: + student_replacements.append(replacement) + + solutions = run_greedy_search( + teacher_replacements=teacher_replacements, + student_replacements=student_replacements, + objective=args.objective, + constraints=mip_constraints, + bigger_is_better=args.bigger_is_better, + ) + else: + solutions = multi_solution_grouped_knapsack( + groups=gathered_metrics, + objective=args.objective, + constraints=mip_constraints, + bigger_is_better=args.bigger_is_better, + num_solutions=args.num_solutions, + minimal_diversity=args.minimal_diversity, + max_seconds_per_solution=args.max_seconds_per_solution, + ) + + for solution in solutions: + for stat_name in set([*orig_mip_constraints.keys(), *args.report_additional_costs]): + try: + non_block_cost = get_nested_key(non_block_stats, stat_name) + except KeyError: + non_block_cost = 0 + solution["total_costs"][stat_name] += non_block_cost + + # Calculate throughput from runtime_ms + if "stats.runtime_ms" in solution["total_costs"]: + total_runtime = solution["total_costs"]["stats.runtime_ms"] + solution["total_costs"]["throughput"] = ( + 1000 * batch_size * generation_seq_len / total_runtime + ) + + solution["total_value"] = {args.objective: solution["total_value"]} + solution["puzzle_args"] = ( + OmegaConf.to_container(args, resolve=True) + if isinstance(args, DictConfig) + else vars(args) + ) + solution["subblock_stats"] = subblock_stats["args"] + chosen_block_configs, _ = extract_block_configs_and_locations( + solution["chosen_replacements"] + ) + solution["chosen_block_configs"] = chosen_block_configs + solution["solution_repr"] = solution_to_str(chosen_block_configs) + + if len(solutions) > 0: + solution_repr_0 = solutions[0]["solution_repr"] + mprint(f"\n{solution_repr_0}") + mprint(f"Total costs: {solutions[0]['total_costs']}") + (output_folder / "solution_repr_0.txt").write_text(solution_repr_0) + + solutions_file = output_folder / "solutions.json" + json_dump(solutions, solutions_file) + mprint(solutions_file) + return solutions_file + + +def _dump_gathered_metrics( + gathered_metrics: PuzzleMetrics, output_folder: Path, is_multi_layer_puzzle: bool = False +) -> None: + if is_multi_layer_puzzle: + for replacement_id, replacement_info in gathered_metrics.items(): + replacement_info["block_repr"] = block_config_to_str(replacement_info["block_config"]) + gathered_metrics_for_dump = gathered_metrics + else: + gathered_metrics_for_dump = { + block_name: { + block_config_to_str(variant_config).strip(): { + **variant_metrics, + "block_config": variant_config, + "block_repr": block_config_to_str(variant_config).strip(), + } + for variant_config, variant_metrics in block_variants.items() + } + for block_name, block_variants in gathered_metrics.items() + } + + json_dump(gathered_metrics_for_dump, output_folder / "replacement_metrics_and_stats.json") + + +def _load_all_constraints(args, puzzle_profile): + def parse_constraints(constraints, constraints_type: PuzzleConstraints.Type): + if isinstance(constraints, (list, ListConfig)): + return [PuzzleConstraints(type=constraints_type, constraints=c) for c in constraints] + elif isinstance(constraints, (dict, DictConfig)): + return [PuzzleConstraints(type=constraints_type, constraints=constraints)] + raise TypeError(f"Invalid constraints type: {constraints_type}") + + # Constraints can be given explicitely + if args.mip_constraints is not None: + return parse_constraints(args.mip_constraints, PuzzleConstraints.Type.MIP) + + if args.human_constraints is not None: + return parse_constraints(args.human_constraints, PuzzleConstraints.Type.HUMAN) + + # Or through the puzzle_profile + if "mip_constraints" in puzzle_profile: + return parse_constraints(puzzle_profile["mip_constraints"], PuzzleConstraints.Type.MIP) + + if "human_constraints" in puzzle_profile: + return parse_constraints(puzzle_profile["human_constraints"], PuzzleConstraints.Type.HUMAN) + + raise ValueError( + "Constraints must be given either explicitely by --mip_constraints or --human_constraints arguments, or through the puzzle_profile." + ) + + +def _load_all_subblock_stats_args(args, puzzle_profile): + # If given explicitely in args + if args.subblock_stats_args is not None: + if isinstance(args.subblock_stats_args, dict): + return [args.subblock_stats_args] + else: + return args.subblock_stats_args + + # Or can be given inside puzzle_profile + if "subblock_stats_args" in puzzle_profile: + return puzzle_profile["subblock_stats_args"] + + raise ValueError( + "subblock_stats_args must be given either explicitely by the --subblock_stats_args argument, or through the puzzle_profile." + ) + + +def _override_args_from_profile(args, puzzle_profile): + for arg_name in vars(args): + if arg_name in puzzle_profile: + if arg_name not in ("mip_constraints", "human_constraints", "subblock_stats_args"): + setattr(args, arg_name, puzzle_profile[arg_name]) + if isinstance(args.constrain_search_args, str): + args.constrain_search_args = parse_json(args.constrain_search_args) + assert args.is_multi_layer_puzzle, "multi-layer puzzle is now the only supported mode." + + +def _assert_valid_config(args, puzzle_profile): + required_args = ( + "subblock_stats_path", + "objective", + "num_solutions", + "minimal_diversity", + "output_path", + ) + missing_args = [arg for arg in required_args if arg not in args or getattr(args, arg) is None] + if missing_args: + mprint(f"error: The following arguments are required: {', '.join(missing_args)}") + exit(1) + + # Make sure we have specified subblock_stats_args + if "subblock_stats_args" not in args and "subblock_stats_args" not in puzzle_profile: + mprint( + "error: Must specify `subblock_stats_arrs` in either puzzle_profile or as a commandline arg." + ) + exit(1) + + # Make sure we have specified constraints + if ( + "mip_constraints" not in args + and "human_constraints" not in args + and "mip_constraints" not in puzzle_profile + and "human_constraints" not in puzzle_profile + ): + mprint( + "error: Must specify either `mip_constraints` or `human_constraints` in one of puzzle_profile or as a commandline argument." + ) + exit(1) + + if args.use_greedy_search: + assert args.is_multi_layer_puzzle, ( + "--use_greedy_search is only supported for multi layer puzzle" + ) + + +def _get_minimal_unique_names(dicts: List[dict]) -> List[str]: + all_keys = set(k for d in dicts for k in d.keys()) + all_values = {k: set(d[k] for d in dicts if k in d) for k in all_keys} + non_common_keys = [k for k, values in all_values.items() if len(values) > 1] + + return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts] + + +def run_puzzle(args: argparse.Namespace) -> List[str]: + # Loads config from args/puzzle_profile + if args.puzzle_profile is not None: + with open(args.puzzle_profile) as f: + puzzle_profile = yaml.safe_load(f) + _override_args_from_profile(args, puzzle_profile) + mprint(f"Loaded Puzzle profile from {args.puzzle_profile}") + else: + puzzle_profile = {} + _assert_valid_config(args, puzzle_profile) + + # Read Metrics and Stats + if args.gathered_metrics_path is not None: + gathered_metrics = json.loads(args.gathered_metrics_path.read_text()) + else: + gather_func = ( + gather_puzzle_metrics + if not args.is_multi_layer_puzzle + else gather_multi_layer_puzle_metrics + ) + gathered_metrics = gather_func(args.single_block_replacement_validation_dir) + + if args.metric_overrides is not None: + gathered_metrics = {**gathered_metrics, **args.metric_overrides} + + if args.constrain_search_func is not None: + mprint(f"{args.constrain_search_args=}") + # assert not args.is_multi_layer_puzzle, "conditional search is not implementd yet for multi-layer puzzles, did you implement it?" + gathered_metrics = css.apply( + args.constrain_search_func, gathered_metrics, args.constrain_search_args + ) + + subblock_stats = json.loads(args.subblock_stats_path.read_text()) + + all_subblock_args = _load_all_subblock_stats_args(args, puzzle_profile) + all_subblock_output_folders = [ + args.output_path / unique_name + for unique_name in _get_minimal_unique_names(all_subblock_args) + ] + all_constraints = _load_all_constraints(args, puzzle_profile) + + # Running all puzzles + solution_paths = [] + for subblock_stats_args, subblock_stats_output_folder in zip( + all_subblock_args, all_subblock_output_folders + ): + for constraints in all_constraints: + output_folder = subblock_stats_output_folder / constraints.name + _solution_path = run_single_puzzle_config( + args, + gathered_metrics, + subblock_stats, + subblock_stats_args, + constraints, + output_folder, + ) + solution_paths.append(_solution_path) + return solution_paths + + +def gather_puzzle_metrics( + single_block_replacement_validation_dir: Path, +) -> PuzzleMetrics: + single_block_metrics = [ + _parse_single_block_replacement_metrics(metrics_path) + for metrics_path in single_block_replacement_validation_dir.glob("*solution*.json") + ] + all_metric_names = tuple(single_block_metrics[0]["metrics"].keys()) + teacher_metrics = _parse_teacher_block_metrics( + single_block_replacement_validation_dir, all_metric_names + ) + + n_layer = len(teacher_metrics) + gathered_metrics = {f"block_{block_idx}": dict() for block_idx in range(n_layer)} + for variant_metrics in single_block_metrics + teacher_metrics: + block_config = variant_metrics["block_config"] + block_name = f"block_{variant_metrics['block_idx']}" + # if we explicitly measure teacher's blocks don't override them + gathered_metrics[block_name][block_config] = variant_metrics + # if not gathered_metrics[block_name].get(block_config): + # gathered_metrics[block_name][block_config] = variant_metrics + return gathered_metrics + + +def gather_multi_layer_puzle_metrics( + single_replacement_validation_dir: Path, +) -> MultiLayerPuzzleMetrics: + single_sequence_metrics = [ + _parse_single_sequence_replacement_metrics(metrics_path) + for metrics_path in single_replacement_validation_dir.glob("*solution*.json") + ] + all_metric_names = tuple(single_sequence_metrics[0]["metrics"].keys()) + teacher_metrics = _parse_teacher_block_metrics( + single_replacement_validation_dir, all_metric_names + ) + + gathered_metrics = { + f"replacement_{replacement_id}": replacement_metrics + for replacement_id, replacement_metrics in enumerate( + single_sequence_metrics + teacher_metrics + ) + } + + return gathered_metrics + + +def _parse_single_block_replacement_metrics(metrics_path: Path) -> dict: + raw_metrics = json.loads(metrics_path.read_text()) + single_block_replacement = raw_metrics["puzzle_solution"]["single_block_replacement"] + variant_metrics = { + "block_config": BlockConfig(**single_block_replacement["block_config"]), + "block_idx": single_block_replacement["block_idx"], + "metrics": _extract_average_metrics(raw_metrics), + } + return variant_metrics + + +def _parse_single_sequence_replacement_metrics(metrics_path: Path) -> dict: + raw_metrics = json.loads(metrics_path.read_text()) + single_sequence_replacement = raw_metrics["puzzle_solution"]["single_sequence_replacement"] + if len(single_sequence_replacement["child_block_configs"]) > 1: + raise NotImplementedError( + "Currently we only support many-to-1 replacements, but we can support many-to-many! " + ) + variant_metrics = { + "block_config": BlockConfig(**single_sequence_replacement["child_block_configs"][0]), + # is there cases where child_block_configs has more than one entry? + "parent_layer_indices": single_sequence_replacement["parent_layer_indices"], + "metrics": _extract_average_metrics(raw_metrics), + "layer_replacement": parse_layer_replacement(single_sequence_replacement), + "is_teacher": False, + } + return variant_metrics + + +def _parse_teacher_block_metrics( + single_block_replacement_validation_dir: Path, + all_metric_names: Iterable[str] = ("kl_div_loss",), +) -> list[dict]: + raw_metrics = json.loads((single_block_replacement_validation_dir / "teacher.json").read_text()) + teacher_checkpoint_dir = Path(raw_metrics["args"]["teacher_dir"]).resolve() + teacher_model_config = load_model_config(teacher_checkpoint_dir) + + teacher_replacements = None + replacement_library_path = raw_metrics["args"].get("replacement_library_path") + if replacement_library_path is not None: + teacher_replacements = dict() + all_layer_replacements = json.loads(Path(replacement_library_path).read_text()) + for layer_replacement in all_layer_replacements: + layer_replacement = parse_layer_replacement(layer_replacement) + if replacement_is_teacher( + layer_replacement, teacher_model_config, teacher_checkpoint_dir + ): + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_replacements[block_idx] = layer_replacement + + teacher_metrics = [ + { + "block_config": block_config, + "block_idx": block_idx, + "parent_layer_indices": [block_idx], + "metrics": { + **{ + metric_name: 0.0 for metric_name in all_metric_names + }, # default value 0. for teacher + **_extract_average_metrics(raw_metrics), # override with real value if exists + }, + **( + {"layer_replacement": teacher_replacements[block_idx]} + if teacher_replacements is not None + else {} + ), + "is_teacher": True, + } + for block_idx, block_config in enumerate(teacher_model_config.block_configs) + ] + return teacher_metrics + + +def _extract_average_metrics(raw_metrics: dict[str, dict]) -> dict[str, float]: + average_metrics = dict() + for metric_name in raw_metrics.keys(): + metric_dict = raw_metrics[metric_name] + if isinstance(metric_dict, dict) and ("avg" in metric_dict.keys()): + metric_value = raw_metrics[metric_name]["avg"] + average_metrics[metric_name] = metric_value + average_metrics[f"one_minus_{metric_name}"] = 1 - metric_value + return average_metrics + + +def filter_subblock_stats_by_args( + all_subblock_stats: list[dict], + subblock_stats_args: dict[str, Any], + convert_dicts_to_dataclasses: bool = True, +) -> dict[str, dict]: + matching_subblock_stats = [ + subblock_stats + for subblock_stats in all_subblock_stats + if _dict_is_subset(subblock_stats_args, subblock_stats["args"]) + ] + assert len(matching_subblock_stats) == 1, ( + "The provided subblock_stats_args should match exactly one measurement " + f"scenario, instead matched {len(matching_subblock_stats)}:\n" + f"{[m['args'] for m in matching_subblock_stats]}" + ) + subblock_stats = deepcopy(matching_subblock_stats[0]) + + if convert_dicts_to_dataclasses: + class_name_to_class = {klass.__name__: klass for klass in [AttentionConfig, FFNConfig]} + subblocks_dict = dict() + for substats in subblock_stats["subblocks"]: + subblock_config_class = class_name_to_class[substats.pop("subblock_config_class")] + subblock_config = subblock_config_class(**substats.pop("subblock_config")) + dict_key = (subblock_config, None) + if "parent_layer_index" in substats: + dict_key = (subblock_config, substats["parent_layer_index"]) + subblocks_dict[dict_key] = substats + subblock_stats["subblocks"] = subblocks_dict + return subblock_stats + + +def _dict_is_subset(dict1: dict, dict2: dict) -> bool: + return all(item in dict2.items() for item in dict1.items()) + + +def _add_block_stats_to_gathered_metrics( + gathered_metrics: PuzzleMetrics, subblock_stats: dict +) -> None: + for block_name, block_variants in gathered_metrics.items(): + parent_layer_index = None + if "parent_layer_indices" in block_variants: + parent_layer_index = block_variants["parent_layer_indices"][0] + + if "metrics" in block_variants: + # this is a sequence stats object for multi-layer puzzle + block_variants["stats"] = _get_block_stats( + subblock_stats, block_variants["block_config"], parent_layer_index + ) + else: + for block_config, variant_metrics in block_variants.items(): + variant_metrics["stats"] = _get_block_stats(subblock_stats, block_config) + + +def _get_block_stats( + subblock_stats: dict, + block_config: BlockConfig | Literal["non_block"], + parent_layer_index: int = None, +) -> dict[str, float]: + if block_config == "non_block": + return subblock_stats["non_block"] + + if block_config.parallel_blocks is None: + attention_key = (block_config.attention, parent_layer_index) + ffn_key = (block_config.ffn, parent_layer_index) + attention_stats = subblock_stats["subblocks"][attention_key] + ffn_stats = subblock_stats["subblocks"][ffn_key] + assert set(attention_stats.keys()) == set(ffn_stats.keys()) + + block_stats = dict() + for k in attention_stats.keys(): + block_stats[k] = _none_add(attention_stats[k], ffn_stats[k]) + block_stats[f"attention_{k}"] = attention_stats[k] + block_stats[f"ffn_{k}"] = ffn_stats[k] + + block_stats["has_attention"] = int( + not block_config.attention.no_op and block_config.attention.mamba is None + ) + block_stats["has_ffn"] = int(not block_config.ffn.no_op) + block_stats["has_moe"] = int(block_config.ffn.moe is not None) + block_stats["not_no_op"] = int( + not (block_config.attention.no_op and block_config.ffn.no_op) + ) + block_stats["num_kv_heads"] = ( + subblock_stats["args"]["n_head"] // block_config.attention.n_heads_in_group + if block_stats["has_attention"] + else 0 + ) + block_stats["num_local_experts"] = ( + block_config.ffn.moe.num_local_experts if block_stats["has_moe"] else 0 + ) + + return block_stats + + # this is a parallel block + ADDITIVE_METRICS = ("memory_mib", "num_params", "kv_cache_memory_mib") + ADDITIVE_METRICS = [ + f"{prefix}{metric}" for prefix in ("", "attention_", "ffn_") for metric in ADDITIVE_METRICS + ] + block_stats = [ + _get_block_stats(subblock_stats, sub_parallel) + for sub_parallel in block_config.parallel_blocks + ] + block_stats = { + k: _none_add_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats]) + if k in ADDITIVE_METRICS + else _none_max_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats]) + for k in block_stats[0].keys() + } + + return block_stats + + +def _none_add(a: float | int | None, b: float | int | None) -> float | int | None: + if a is None or b is None: + return None + return a + b + + +def _none_max(a: float | int | None, b: float | int | None) -> float | int | None: + if a is None or b is None: + return None + return max(a, b) + + +def _none_add_list(l) -> float | int | None: + r = l[0] + if len(l) == 1: + return r + for e in l[1:]: + r = _none_add(r, e) + return r + + +def _none_max_list(l) -> float | int | None: + r = l[0] + if len(l) == 1: + return r + for e in l[1:]: + r = _none_max(r, e) + return r + + +if __name__ == "__main__": + args = parse_args() + run_puzzle(args) diff --git a/modelopt/torch/_compress/mip/utils.py b/modelopt/torch/_compress/mip/utils.py new file mode 100644 index 0000000000..7398203cc2 --- /dev/null +++ b/modelopt/torch/_compress/mip/utils.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for MIP optimization.""" + +from typing import Any + + +class InfeasibleError(Exception): + """Exception raised when optimization problem is infeasible.""" + + pass + + +def sort_replacements(layer_replacements: list[dict]) -> list[dict]: + """Sort layer replacements by parent layer indices. + + Args: + layer_replacements: List of replacement dictionaries containing "parent_layer_indices" + + Returns: + Sorted list of replacements + """ + return sorted(layer_replacements, key=lambda replacement: replacement["parent_layer_indices"]) + + +def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any: + """Access nested dictionary values using dot notation. + + If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"] + + Args: + dictionary: Dictionary to access + nested_key: Dot-separated key path (e.g., "a.b.c") + + Returns: + Value at the nested key location + """ + value = dictionary + for key in nested_key.split("."): + value = value[key] + return value + + +def consecutive_ngrams(l: int, n: int) -> list[list[int]]: + """Generate all consecutive n-grams from range(l). + + Splits range(l) into all consecutive n-grams. + + Args: + l: Length of the range + n: Size of each n-gram + + Returns: + List of consecutive n-grams + + Example: + consecutive_ngrams(7, 2) = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]] + """ + ngrams = [] + for i in range(l - n + 1): + ngrams.append(list(range(i, i + n))) + return ngrams diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 390ba835a7..5c08c693a2 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -23,10 +23,10 @@ import datetime from pathlib import Path -import mip_and_realize_models import torch from torch import nn +import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch._compress.scoring.scoring as scoring from modelopt.torch._compress import build_library_and_stats diff --git a/modelopt/torch/_compress/replacement_library/replacement_utils.py b/modelopt/torch/_compress/replacement_library/replacement_utils.py index 21ae411752..331357d2bb 100644 --- a/modelopt/torch/_compress/replacement_library/replacement_utils.py +++ b/modelopt/torch/_compress/replacement_library/replacement_utils.py @@ -24,6 +24,7 @@ from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.mip.utils import sort_replacements def parse_layer_replacement(layer_replacement: dict | str) -> dict: @@ -43,8 +44,7 @@ def parse_layer_replacement(layer_replacement: dict | str) -> dict: return layer_replacement -def sort_replacements(layer_replacements: list[dict]) -> list[dict]: - return sorted(layer_replacements, key=lambda replacement: replacement["parent_layer_indices"]) +# sort_replacements moved to modelopt.torch._compress.mip.utils and imported above def extract_block_configs_and_locations( diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index 7acb3f3274..d03ea80403 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -14,6 +14,9 @@ # limitations under the License. import dataclasses +import json +import os +from copy import deepcopy from typing import Any import torch @@ -75,6 +78,49 @@ def sizeof_dtype(dtype: torch.dtype | str) -> int | float: return torch.tensor([], dtype=dtype).element_size() +def load_json(file_path: str): + """Load and parse a JSON file. + + TODO: Consider a better place for this function. + + Args: + file_path: Path to the JSON file to load. + + Returns: + Parsed JSON data as a Python object, or None if the file doesn't exist. + """ + if not os.path.exists(file_path): + print("file does not exist {file_path}") + return None + + with open(file=file_path) as f: + return json.load(f) + + +def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str: + """Convert a list of block configurations to a human-readable string representation. + + TODO: Consider a better place for this function. + Better place for this and subsequent related function would be in __repr__ function in class + BlockConfig so when we print it or do str(block_config), it automatically + prints in this custom formatted string + + Args: + block_configs: List of BlockConfig dataclasses or dicts containing layer configurations. + + Returns: + Multi-line string with each block's configuration on a separate line. + """ + block_configs = deepcopy(block_configs) + reps = [] + for block_idx, block_config in enumerate(block_configs): + rep = f"block_{block_idx}:".ljust(9) + rep += block_config_to_str(block_config) + reps.append(rep) + rep = "\n".join(reps) + "\n" + return rep + + def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: """ Convert a BlockConfig to a human-readable string representation. diff --git a/setup.py b/setup.py index 3eb41967d1..1b85b41aa1 100644 --- a/setup.py +++ b/setup.py @@ -110,6 +110,7 @@ "typeguard", "pandas", "immutabledict", + "mip", ], } diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/gpu/torch/_compress/compress_test_utils.py similarity index 98% rename from tests/experimental/torch/_compress/compress_test_utils.py rename to tests/gpu/torch/_compress/compress_test_utils.py index ce22e1864c..a1102e7fac 100644 --- a/tests/experimental/torch/_compress/compress_test_utils.py +++ b/tests/gpu/torch/_compress/compress_test_utils.py @@ -118,7 +118,7 @@ def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: """ Create a tokenizer for the Llama model. """ - tokenizer_path = project_root_path / "tests/experimental/torch/_compress/resources/tokenizer" + tokenizer_path = project_root_path / "tests/gpu/torch/_compress/resources/tokenizer" tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) return tokenizer diff --git a/tests/experimental/torch/_compress/conftest.py b/tests/gpu/torch/_compress/conftest.py similarity index 100% rename from tests/experimental/torch/_compress/conftest.py rename to tests/gpu/torch/_compress/conftest.py diff --git a/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py index 1f0283b3e8..7576f270b3 100644 --- a/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py +++ b/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py @@ -16,7 +16,7 @@ import json from pathlib import Path -from experimental.torch._compress.compress_test_utils import ( +from gpu.torch._compress.compress_test_utils import ( create_and_save_small_llama_model, create_tokenizer, ) diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py b/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py similarity index 92% rename from tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py rename to tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py index cf284cfc87..dbbcbacd47 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py @@ -20,7 +20,7 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from experimental.torch._compress.compress_test_utils import setup_test_model_and_data +from gpu.torch._compress.compress_test_utils import setup_test_model_and_data import modelopt.torch.nas as mtn from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel @@ -28,7 +28,7 @@ # -# See tests/experimental/torch/_compress/test_compress.py for instructions on how to run this test +# See tests/gpu/torch/_compress/test_compress.py for instructions on how to run this test # TODO: Remove those instructions once this test runs automatically on CI # def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): @@ -49,9 +49,7 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, runtime ) - hydra_config_dir = ( - project_root_path / "tests/experimental/torch/_compress/resources/configs" - ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" # @@ -111,9 +109,7 @@ def _test_nas_convert_attn_pruning_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, runtime ) - hydra_config_dir = ( - project_root_path / "tests/experimental/torch/_compress/resources/configs" - ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" hydra_config_name = "Llama-3_1-8B-attn-pruning" # diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py b/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py similarity index 92% rename from tests/experimental/torch/_compress/nas/plugins/test_nas_search.py rename to tests/gpu/torch/_compress/nas/plugins/test_nas_search.py index 4a6a3eccec..e8ea24ecee 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py @@ -14,7 +14,7 @@ # limitations under the License. # -# See tests/experimental/torch/_compress/test_compress.py for instructions on how to run this test +# See tests/gpu/torch/_compress/test_compress.py for instructions on how to run this test # TODO: Remove those instructions once this test runs automatically on CI # import datetime @@ -23,7 +23,7 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from experimental.torch._compress.compress_test_utils import setup_test_model_and_data +from gpu.torch._compress.compress_test_utils import setup_test_model_and_data import modelopt.torch.nas as mtn from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel @@ -48,9 +48,7 @@ def _test_nas_search_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, runtime ) - hydra_config_dir = ( - project_root_path / "tests/experimental/torch/_compress/resources/configs" - ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" # diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml b/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml rename to tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml b/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml rename to tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml b/tests/gpu/torch/_compress/resources/configs/pruning/attn_pruning.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml rename to tests/gpu/torch/_compress/resources/configs/pruning/attn_pruning.yaml diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml b/tests/gpu/torch/_compress/resources/configs/pruning/ffn_pruning.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml rename to tests/gpu/torch/_compress/resources/configs/pruning/ffn_pruning.yaml diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml rename to tests/gpu/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml b/tests/gpu/torch/_compress/resources/configs/pruning/pruning_defaults.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml rename to tests/gpu/torch/_compress/resources/configs/pruning/pruning_defaults.yaml diff --git a/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml rename to tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml diff --git a/tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml b/tests/gpu/torch/_compress/resources/configs/validate_solutions_defaults.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml rename to tests/gpu/torch/_compress/resources/configs/validate_solutions_defaults.yaml diff --git a/tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json b/tests/gpu/torch/_compress/resources/tokenizer/special_tokens_map.json similarity index 100% rename from tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json rename to tests/gpu/torch/_compress/resources/tokenizer/special_tokens_map.json diff --git a/tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json b/tests/gpu/torch/_compress/resources/tokenizer/tokenizer.json similarity index 100% rename from tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json rename to tests/gpu/torch/_compress/resources/tokenizer/tokenizer.json diff --git a/tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json b/tests/gpu/torch/_compress/resources/tokenizer/tokenizer_config.json similarity index 100% rename from tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json rename to tests/gpu/torch/_compress/resources/tokenizer/tokenizer_config.json diff --git a/tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py b/tests/gpu/torch/_compress/resources/tokenizer/truncate_tokenizer.py similarity index 100% rename from tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py rename to tests/gpu/torch/_compress/resources/tokenizer/truncate_tokenizer.py diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/gpu/torch/_compress/test_compress.py similarity index 93% rename from tests/experimental/torch/_compress/test_compress.py rename to tests/gpu/torch/_compress/test_compress.py index 76407bc1f0..b00be24857 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/gpu/torch/_compress/test_compress.py @@ -20,7 +20,7 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from experimental.torch._compress.compress_test_utils import setup_test_model_and_data +from gpu.torch._compress.compress_test_utils import setup_test_model_and_data from modelopt.torch._compress import compress from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( @@ -45,7 +45,7 @@ # # export PYTHONPATH=$PYTHONPATH:.:/workspace/puzzletron/v1 # -# pytest -s -v ./tests/experimental/torch/_compress/test_compress.py::test_compress -o addopts="" +# pytest -s -v ./tests/gpu/torch/_compress/test_compress.py::test_compress -o addopts="" def test_compress(project_root_path: Path, tmp_path: Path): @@ -64,9 +64,7 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, runtime ) - hydra_config_dir = ( - project_root_path / "tests/experimental/torch/_compress/resources/configs" - ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" # Convert the Llama model to DeciLM model. From a99f503ad427fa08c3d074fba44f2d221991f8f7 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 8 Dec 2025 20:59:13 +0530 Subject: [PATCH 23/62] Remove unused mip functions + fix multi-gpu test (#660) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? **Type of change:** Improvement - Fix tests for 2-gpu: Some places hard-coded cpu device for distributed communications which was causing this issue - Remove unused constrain_search_space.py - Remove `is_multi_layer_puzzle: False` case - Remove `use_greedy_search: False` case - Remove knapsack mip case - Remove unused `num_solutions` and `minimal_diversity` flags ## Testing - GH CICD test passing - Tested on 2-gpu setup locally as well ## Summary by CodeRabbit # Release Notes * **Refactor** * Optimized solver implementation with improved library integration. * Simplified model compression configuration by removing deprecated search options. * Consolidated optimization paths for streamlined processing. * **Chores** * Updated dependencies for improved compatibility. * **Documentation** * Clarified Model-Optimizer installation instructions in examples. ✏️ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/compress/README.md | 4 +- .../Llama-3_1-8B.yaml | 15 +- examples/pruning/README.md | 2 +- .../_compress/mip/constrain_search_space.py | 407 ------------------ ...dy_search_with_multi_layer_replacements.py | 180 -------- .../torch/_compress/mip/grouped_knapsack.py | 231 ---------- .../_compress/mip/mip_and_realize_models.py | 13 +- .../mip/mip_with_multi_layer_replacements.py | 7 +- modelopt/torch/_compress/mip/run_puzzle.py | 118 +---- modelopt/torch/_compress/sewing_kit/utils.py | 27 +- setup.py | 10 +- .../torch/_compress/compress_test_utils.py | 10 +- .../configs/Llama-3_1-8B-attn-pruning.yaml | 15 +- .../configs/Llama-3_1-8B-ffn-pruning.yaml | 15 +- tests/gpu/torch/_compress/test_compress.py | 14 - 15 files changed, 82 insertions(+), 986 deletions(-) delete mode 100644 modelopt/torch/_compress/mip/constrain_search_space.py delete mode 100644 modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py delete mode 100644 modelopt/torch/_compress/mip/grouped_knapsack.py diff --git a/examples/compress/README.md b/examples/compress/README.md index 3bd218aa48..755b6090e8 100644 --- a/examples/compress/README.md +++ b/examples/compress/README.md @@ -13,7 +13,7 @@ In this example, we compress the [meta-llama/Llama-3.1-8B-Instruct](https://hugg ## Environment -- Install TensorRT-Model-Optimizer in editable mode with the corresponding dependencies: +- Install Model-Optimizer in editable mode with the corresponding dependencies: ```bash pip install -e .[hf,compress] @@ -94,7 +94,7 @@ pip install -e .[hf,compress] block_29: attention gqa_4 ffn intermediate_14336 block_30: attention gqa_4 ffn intermediate_14336 block_31: attention gqa_4 ffn intermediate_14336 - + [2025-11-02 04:53:11,332]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 75796.4140625, 'stats.ffn_num_params': 5637275648, 'stats.num_kv_heads': 160, 'stats.kv_cache_memory_mib': 61440.0, 'stats.ffn_memory_mib': 10752.25, 'stats.attention_memory_mib': 63040.15625, 'stats.attention_num_params': 838942720, 'stats.num_params': 7526895616, 'stats.has_attention': 20, 'stats.has_ffn': 32} ... ################################################################ diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml index 70b5304c5b..133fe0b777 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -9,7 +9,7 @@ defaults: puzzle_dir: ??? teacher_dir: ${puzzle_dir}/ckpts/teacher/ replacement_library_path: ${puzzle_dir}/replacement_library.json -dataset_path: ??? # path to v0.4_mini +dataset_path: ??? # path to v0.4_mini skip_realize_model: false @@ -21,10 +21,10 @@ calc_subblock_stats: batch_sizes: [64, 96, 128] prefill_seq_len: 4096 generation_seq_len: 4096 - num_active_tokens_override: # Optional override for sequence lengths + num_active_tokens_override: # Optional override for sequence lengths prefill_queue_size: 0 allocate_prefill_query: false - benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking merge_with_existing_stats: false subblock_stats_filename: "subblock_stats.json" moe_stats_filename: "moe_stats.json" @@ -56,8 +56,6 @@ mip: # puzzle_profile: objective: metrics.cosine_embedding_loss_hidden_states bigger_is_better: false - num_solutions: 1 - minimal_diversity: 2 subblock_stats_args: - batch_size: 96 @@ -81,10 +79,7 @@ mip: target_memory: 78_000 mip_constraints: - use_greedy_search: false - is_multi_layer_puzzle: true metric_overrides: - constrain_search_func: max_seconds_per_solution: 60 realize_model: @@ -92,10 +87,10 @@ realize_model: tokenizer_name: ${to_path:${teacher_dir}} replacement_library_path: ${replacement_library_path} save_models: true - solutions_path: # Filled dynamically + solutions_path: # Filled dynamically # Validate params - skip_validation: false # To enable validation of the model solution set `skip_validation` as False + skip_validation: false # To enable validation of the model solution set `skip_validation` as False eval_samples: 128 micro_batch_size: 1 seed: 42 diff --git a/examples/pruning/README.md b/examples/pruning/README.md index bbc0e7bdea..9e5188e623 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -23,7 +23,7 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar -For more advanced pruning strategies, such as the [Puzzle methodology](https://arxiv.org/pdf/2411.19146), please see [Puzzle pruning example](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/feature/compress/examples/compress). +For more advanced pruning strategies, such as the [Puzzle methodology](https://arxiv.org/pdf/2411.19146), please see [Puzzle pruning example](../compress/README.md). ## Pre-Requisites diff --git a/modelopt/torch/_compress/mip/constrain_search_space.py b/modelopt/torch/_compress/mip/constrain_search_space.py deleted file mode 100644 index e30ee24783..0000000000 --- a/modelopt/torch/_compress/mip/constrain_search_space.py +++ /dev/null @@ -1,407 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Constrains the search space for the MIP optimization.""" - -import traceback - -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( - AttentionConfig, - BlockConfig, - FFNConfig, -) -from modelopt.torch._compress.utils.utils import load_json - - -def drop_attentions_only(gathered_metrics, teacher_intermediate_size, teacher_n_heads_in_group): - """ - changes the search space such that puzzle is not allowed to change the ffns - but is only allowed to drop or reduce attention. - - Usage example: - add the following flags to your run_puzzle command: - - --constrain_search_func drop_attentions_only --constrain_search_args {\"teacher_intermediate_size\": 14336, \"teacher_n_heads_in_group\": 16, \"above_layer\": 60} - - """ - - for block_name, block_variants in gathered_metrics.items(): - to_delete = [] # Collect keys to delete after the loop - for variant_config, variant_metrics in block_variants.items(): - block_intermediate_size = variant_config.ffn.intermediate_size - block_attn_n_heads = variant_config.attention.n_heads_in_group - if ( - ( - block_intermediate_size is not None - and block_intermediate_size != teacher_intermediate_size - ) - or variant_config.ffn.replace_with_linear - or variant_config.ffn.no_op ## uncomment this line if you want to drop only attns - or variant_config.attention.replace_with_linear - or ( - block_attn_n_heads is not None - and block_attn_n_heads != teacher_n_heads_in_group - ) - ): - print(f"Marking for deletion: {block_name}-{variant_config}") - to_delete.append(variant_config) - for key in to_delete: - del block_variants[key] - - print("new search space in block 0", gathered_metrics["block_0"]) - return gathered_metrics - - -def reduce_only_ffns( - gathered_metrics, - teacher_intermediate_size: int, - teacher_n_heads_in_group: int, - above_layer: int, - allow_no_ops: bool, -): - """ - only allows to reduce FFNs but not to completely drop them from layer 60 onwards - attention is only allowed to be like uniform teacher - - Usage example: - add the following flags to your run_puzzle command: - constrain_search_args='{"teacher_intermediate_size": 14336, "teacher_n_heads_in_group": 16, "above_layer": 60, "allow_no_ops": false}' - - sbatch puzzle/cli/run_puzzle ... --constrain_search_func reduce_only_ffns --constrain_search_args="$(echo "$constrain_search_args" | jq -c .)" - """ - print(f"{teacher_n_heads_in_group=}") - for block_name, block_variants in gathered_metrics.items(): - to_delete = [] # Collect keys to delete after the loop - block_id = int(block_name.split("_")[1]) - - for variant_config, variant_metrics in block_variants.items(): - block_intermediate_size = variant_config.ffn.intermediate_size - block_attn_n_heads = variant_config.attention.n_heads_in_group - - attn_no_op = variant_config.attention.no_op - attn_linear = variant_config.attention.replace_with_linear - if ( - attn_no_op - or attn_linear - or (block_attn_n_heads != teacher_n_heads_in_group) # keep attention as the teacher - or ( - block_id <= above_layer - and (block_intermediate_size != teacher_intermediate_size) - ) - or ((not allow_no_ops) and variant_config.ffn.no_op) - ): - # print(f"Marking for deletion: {block_name}-{variant_config}") - to_delete.append(variant_config) # Add key to delete list - - for key in to_delete: - del block_variants[key] - - print("new search space in block 0", gathered_metrics["block_0"]) - return gathered_metrics - - -def drop_entire_blocks_only(gathered_metrics): - teacher_block_config = _infer_teacher_config(gathered_metrics) - for block_name, block_variants in gathered_metrics.items(): - to_delete = [] # Collect keys to delete after the loop - for variant_config, variant_metrics in block_variants.items(): - is_no_op_block = ( - variant_config.ffn.no_op - and variant_config.attention.no_op - and getattr(variant_config, "parallel_blocks", None) is None - ) - is_teacher = variant_config == teacher_block_config - if not is_no_op_block and not is_teacher: - to_delete.append(variant_config) - for key in to_delete: - del block_variants[key] - - print("new search space in block 0", gathered_metrics["block_0"]) - return gathered_metrics - - -def css_to_reference_attention(gathered_metrics, attention_pruned_arch): - """ - given a reference architecture we fix the search space to only include options that change the FFNs - but to never change the Attentions from the reference arch's Attentions. - """ - - attention_pruned_arch = load_json(attention_pruned_arch)[0] - attention_dropped_blocks = [ - block_name - for block_name, block_config in attention_pruned_arch["chosen_items"].items() - if block_config["attention"]["no_op"] - ] - - for block_name, block_variants in gathered_metrics.items(): - to_delete = [] # Collect keys to delete after the loop - for variant_config, _ in block_variants.items(): - # Uncomment and adjust this block if needed - # does drop only attention - block_attn_n_heads = variant_config.attention.n_heads_in_group - - reference_arch_attn = attention_pruned_arch["chosen_items"][block_name]["attention"][ - "n_heads_in_group" - ] - if ( # we reduce the search space by keeping the reference arch attention as is - (block_name in attention_dropped_blocks and not variant_config.attention.no_op) - or ( - block_name not in attention_dropped_blocks - and block_attn_n_heads != reference_arch_attn - ) - ): - print(f"Marking for deletion: {block_name}-{variant_config}") - to_delete.append(variant_config) - - # Delete marked keys outside the loop - for key in to_delete: - del block_variants[key] - - print("new search space in block 0", gathered_metrics["block_0"]) - return gathered_metrics - - -def css_to_reference_ffn(gathered_metrics, ffn_pruned_arch, allow_linear_attn=True): - """ - given a reference architecture we fix the search space to only include options that change the Attentions - but to never change the FFNs from the reference arch's FFNs. - """ - - ffn_pruned_arch = load_json(ffn_pruned_arch)[0] - - for block_name, block_variants in gathered_metrics.items(): - to_delete = [] # Collect keys to delete after the loop - for variant_config, _ in block_variants.items(): - block_ffn = variant_config.ffn - is_linear_attn = variant_config.attention.replace_with_linear - - reference_arch_ffn = ffn_pruned_arch["chosen_items"][block_name]["ffn"] - reference_arch_ffn = FFNConfig(**reference_arch_ffn) - - if ( # we reduce the search space by keeping the reference arch ffn as is - (block_ffn != reference_arch_ffn) or (not allow_linear_attn and is_linear_attn) - ): - # print(f"Marking for deletion: {block_name}-{variant_config}") - to_delete.append(variant_config) - - # Delete marked keys outside the loop - for key in to_delete: - del block_variants[key] - - print("new search space in block 0", gathered_metrics["block_0"]) - return gathered_metrics - - -def avoid_variable_gqa( - gathered_metrics, - allow_no_op_attn: bool = True, - allow_linear_attn: bool = False, - target_n_heads_in_group: int = None, -): - """ - Allow only the teacher n_heads_in_group, - and optionally also attention no-op (default allow) - and attention linear (default avoid). - - This reducer affects only the attention layers: FFNs are allowed their entire search space. - """ - is_multi_layer_puzzle = is_replacement_gathered_metrics(gathered_metrics) - if is_multi_layer_puzzle: - teacher_block_config = infer_teacher_replacement_config(gathered_metrics) - else: - teacher_block_config = _infer_teacher_config(gathered_metrics) - - if target_n_heads_in_group is None: - target_n_heads_in_group = teacher_block_config.attention.n_heads_in_group - - if not is_multi_layer_puzzle: - for block_name, block_variants in gathered_metrics.items(): - to_delete = [] # Collect keys to delete after the loop - - for variant_config, variant_metrics in block_variants.items(): - if not ( - (variant_config.attention.n_heads_in_group == target_n_heads_in_group) - or (variant_config.attention.no_op and allow_no_op_attn) - or (variant_config.attention.replace_with_linear and allow_linear_attn) - ): - to_delete.append(variant_config) - - for key in to_delete: - del block_variants[key] - else: - to_delete = [] # Collect keys to delete after the loop - for replacement_id, replacement in gathered_metrics.items(): - variant_config = replacement["block_config"] - if not ( - (variant_config.attention.n_heads_in_group == target_n_heads_in_group) - or (variant_config.attention.no_op and allow_no_op_attn) - or (variant_config.attention.replace_with_linear and allow_linear_attn) - ): - to_delete.append(replacement_id) - - for key in to_delete: - del gathered_metrics[key] - if not is_multi_layer_puzzle: - print("new search space in block 0", gathered_metrics["block_0"]) - else: - parent_layer_idx = 0 - print( - "new search space in block {parent_layer_idx}", - [ - replacement["block_config"] - for replacement_id, replacement in gathered_metrics.items() - if replacement["parent_layer_indices"][0] == parent_layer_idx - ], - ) - return gathered_metrics - - -def reduce_in_range( - gathered_metrics, - layer_start: int, - layer_end: int, -): - """ - Allow only reduction of layers between layer_start and layer_end. Leyers before layers start, and after layer_end are kept as is (the teacher). - - """ - assert layer_start < layer_end, ( - f"Wrong input arguments: {layer_start=} must be less than {layer_end=}" - ) - is_multi_layer_puzzle = is_replacement_gathered_metrics(gathered_metrics) - if is_multi_layer_puzzle: - teacher_block_config = infer_teacher_replacement_config(gathered_metrics) - else: - teacher_block_config = _infer_teacher_config(gathered_metrics) - - to_delete = [] # Collect keys to delete after the loop - for replacement_id, replacement in gathered_metrics.items(): - block_id = max(replacement["parent_layer_indices"]) - variant_config = replacement["block_config"] - is_teacher = variant_config == teacher_block_config - if (block_id < layer_start or block_id > layer_end) and not is_teacher: - to_delete.append(replacement_id) - - for key in to_delete: - del gathered_metrics[key] - - if not is_multi_layer_puzzle: - print("new search space in block 0", gathered_metrics["block_0"]) - else: - parent_layer_idx = 0 - print( - "new search space in block {parent_layer_idx}", - [ - replacement["block_config"] - for replacement_id, replacement in gathered_metrics.items() - if replacement["parent_layer_indices"][0] == parent_layer_idx - ], - ) - return gathered_metrics - - -############################################################################################# - - -# automatically builds a dictionary mapping method names in this module to their functions -# this dictionary is used to dynamically dispatch functions -dispatcher = { - method_name: method_callable - for method_name, method_callable in globals().items() - if callable(method_callable) -} - - -def is_replacement_gathered_metrics(gathered_metrics) -> bool: - # if the gathered metrics is a replacement, then it is a dictionary of the form {'replacement_{id}': replacement_metrics} - - return isinstance(gathered_metrics, dict) and all( - key.startswith("replacement_") for key in gathered_metrics - ) - - -def _infer_teacher_config(gathered_metrics) -> BlockConfig: - n_heads_in_group, intermediate_size = zip( - *[ - (variant_config.attention.n_heads_in_group, variant_config.ffn.intermediate_size) - for block_name, block_variants in gathered_metrics.items() - for variant_config, variant_metrics in block_variants.items() - ] - ) - teacher_n_heads_in_group = min(filter(None, n_heads_in_group)) - teacher_intermediate_size = max(filter(None, intermediate_size)) - - unique_teacher_candidates = set() - for block_name, block_variants in gathered_metrics.items(): - for variant_config, variant_metrics in block_variants.items(): - if ( - variant_config.ffn.intermediate_size == teacher_intermediate_size - and variant_config.attention.n_heads_in_group == teacher_n_heads_in_group - ): - unique_teacher_candidates.add(variant_config) - - assert len(unique_teacher_candidates) == 1, ( - f"Woops, expected example one candidate to be the teacher block config, instead found: {unique_teacher_candidates=}" - ) - - teacher_block_config = unique_teacher_candidates.pop() - return teacher_block_config - - -def infer_teacher_replacement_config(gathered_metrics) -> BlockConfig: - n_heads_in_group, intermediate_size = zip( - *[ - ( - replacement["block_config"].attention.n_heads_in_group, - replacement["block_config"].ffn.intermediate_size, - ) - for replacement_id, replacement in gathered_metrics.items() - ] - ) - teacher_intermediate_size = max(filter(None, intermediate_size)) - teacher_n_heads_in_group = min(filter(None, n_heads_in_group)) - unique_teacher_candidates = set() - for replacement_id, replacement in gathered_metrics.items(): - if ( - replacement["block_config"].ffn.intermediate_size == teacher_intermediate_size - and replacement["block_config"].attention.n_heads_in_group == teacher_n_heads_in_group - ): - unique_teacher_candidates.add(replacement["block_config"]) - - assert len(unique_teacher_candidates) == 1, ( - f"Woops, expected example one candidate to be the teacher block config, instead found: {unique_teacher_candidates=}" - ) - - teacher_replacement_config = unique_teacher_candidates.pop() - return teacher_replacement_config - - -def apply(css_func_name, gathered_metrics, method_kwargs): - search_space_reducer = dispatcher.get(css_func_name) - if search_space_reducer is None: - raise ValueError( - f"could not find a function called `{css_func_name}` in {__name__}.py to reduce search space " - ) - - try: - gathered_metrics = search_space_reducer(gathered_metrics, **method_kwargs) - except Exception as e: - traceback.print_exc() - raise ValueError( - f"something went wrong when trying to apply the following search space reducer `{css_func_name}` \ - with the folloing args: {method_kwargs}, here's the exception: {e}" - ) - - return gathered_metrics diff --git a/modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py b/modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py deleted file mode 100644 index 719643cc22..0000000000 --- a/modelopt/torch/_compress/mip/greedy_search_with_multi_layer_replacements.py +++ /dev/null @@ -1,180 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Performs greedy search to find optimal multi-layer replacements under resource constraints.""" - -# mypy: ignore-errors -import math -from copy import deepcopy -from random import random -from typing import Any, Hashable, TypeAlias - -from .utils import InfeasibleError, consecutive_ngrams, get_nested_key, sort_replacements - -ReplacementID: TypeAlias = Hashable -Replacement: TypeAlias = dict[str, Any] -ChosenReplacements: TypeAlias = list[Replacement] - - -def run_greedy_search( - teacher_replacements: list[Replacement], - student_replacements: list[Replacement], - objective: str, - constraints: dict[str, float], - bigger_is_better: bool, -) -> tuple[ChosenReplacements, float, dict[str, float]]: - print("####### running greedy search #######") - teacher_replacements = deepcopy(teacher_replacements) - student_replacements = deepcopy(student_replacements) - chosen_replacements: ChosenReplacements = [] - - teacher_replacements = { - replacement["parent_layer_indices"][0]: replacement for replacement in teacher_replacements - } - - all_parent_layers = set(teacher_replacements.keys()) - uncovered_parent_layers = set(all_parent_layers) - - while True: - if len(student_replacements) == 0: - raise InfeasibleError() - - choice_func = max if bigger_is_better else min - best_replacement = choice_func( - student_replacements, key=lambda replacement: get_nested_key(replacement, objective) - ) - chosen_replacements.append(best_replacement) - uncovered_parent_layers -= set(best_replacement["parent_layer_indices"]) - student_replacements = _filter_overlapping_replacements( - student_replacements, uncovered_parent_layers - ) - - padded_chosen_replacements = list(chosen_replacements) - for uncovered_block_idx in uncovered_parent_layers: - padded_chosen_replacements.append(teacher_replacements[uncovered_block_idx]) - - all_constraints_satisfied = True - for constraint_key, max_cost in constraints.items(): - total_cost = sum( - get_nested_key(replacement, constraint_key) - for replacement in padded_chosen_replacements - ) - is_constraint_satisfied = total_cost < max_cost or math.isclose( - total_cost, max_cost, rel_tol=1e-9 - ) - if not is_constraint_satisfied: - all_constraints_satisfied = False - - if all_constraints_satisfied: - chosen_replacements = padded_chosen_replacements - break - - # Trust But Verify: calculate total value and costs, and check that all the constraints are filled - total_value = 0.0 - total_costs = {constraint_key: 0 for constraint_key in constraints.keys()} - chosen_layers = set() - for replacement in chosen_replacements: - total_value += get_nested_key(replacement, objective) - for constraint_key in constraints.keys(): - total_costs[constraint_key] += get_nested_key(replacement, constraint_key) - for parent_layer_idx in replacement["parent_layer_indices"]: - assert parent_layer_idx not in chosen_layers, ( - f"Found duplicate chosen layer {parent_layer_idx}" - ) - chosen_layers.add(parent_layer_idx) - - missing_layers = all_parent_layers - set(chosen_layers) - assert len(missing_layers) == 0, ( - f"The following layers were not chosen by any replacement:\n{missing_layers=}\n{chosen_replacements}" - ) - - for constraint_key, max_cost in constraints.items(): - assert total_costs[constraint_key] < max_cost or math.isclose( - total_costs[constraint_key], max_cost, rel_tol=1e-9 - ), ( - f"this constraint was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} <= {max_cost=}" - ) - - chosen_replacements = sort_replacements(chosen_replacements) - for cr in chosen_replacements: - if "block_config" in cr: - cr["child_block_configs"] = cr["block_config"] - - return [ - { - "chosen_replacements": chosen_replacements, - "total_value": total_value, - "total_costs": total_costs, - } - ] - - -def _filter_overlapping_replacements( - replacements: list[Replacement], - uncovered_parent_layers: set[int], -) -> list[Replacement]: - return [ - replacement - for replacement in replacements - if set(replacement["parent_layer_indices"]).issubset(uncovered_parent_layers) - ] - - -def usage_example(): - num_layers = 32 - num_options_per_parent_replacement = 5 - - teacher_replacements = [] - student_replacements = [] - for num_layers_in_replacement in (1, 2, 3): - for i_option in range(num_options_per_parent_replacement): - for parent_layer_indices in consecutive_ngrams(num_layers, num_layers_in_replacement): - is_teacher = num_layers_in_replacement == 1 and i_option == 0 - replacement_id = f"parent layers {parent_layer_indices} child config {i_option}" - replacement = { - "parent_layer_indices": parent_layer_indices, - "metrics": {"loss": random() if not is_teacher else 0.0}, - "stats": {"cost": 1}, - "replacement_id": replacement_id, - } - if is_teacher: - teacher_replacements.append(replacement) - else: - student_replacements.append(replacement) - - constraints = {"stats.cost": num_layers - 8} - (result,) = run_greedy_search( - teacher_replacements, - student_replacements, - objective="metrics.loss", - constraints=constraints, - bigger_is_better=False, - ) - chosen_replacements = result["chosen_replacements"] - total_value = result["total_value"] - total_costs = result["total_costs"] - - print() - print() - print(f"{total_value=}") - print(f"{total_costs=}") - print(f"{constraints=}") - print("chosen_replacements=") - print(chosen_replacements) - print("\n".join([rep["replacement_id"] for rep in chosen_replacements])) - - -if __name__ == "__main__": - usage_example() diff --git a/modelopt/torch/_compress/mip/grouped_knapsack.py b/modelopt/torch/_compress/mip/grouped_knapsack.py deleted file mode 100644 index 5769ded3cd..0000000000 --- a/modelopt/torch/_compress/mip/grouped_knapsack.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Solves the grouped knapsack problem using Mixed Integer Programming to find optimal item selections.""" - -# mypy: ignore-errors -import math -import warnings -from copy import deepcopy -from random import random -from typing import Any, Hashable, Iterable, Optional, TypeAlias, Union - -from mip import BINARY, Model, maximize, minimize, xsum -from tqdm import tqdm - -from .utils import InfeasibleError, get_nested_key - -Item: TypeAlias = dict[str, float | dict[str, float]] -Group: TypeAlias = dict[Hashable, Item] -ChosenItems: TypeAlias = dict[Hashable, Hashable] - - -def multi_solution_grouped_knapsack( - groups: dict[Hashable, Group], - objective: str, - constraints: dict[str, float], - bigger_is_better: bool, - num_solutions: int, - minimal_diversity: int = 1, - max_seconds_per_solution: Optional[float] = None, -) -> list[dict[str, Union[ChosenItems, float]]]: - solutions = [] - previous_choices = [] - for i_run in tqdm(range(num_solutions), desc="multi_solution_grouped_knapsack"): - try: - chosen_items, total_value, total_costs = grouped_knapsack( - groups, - objective, - constraints, - bigger_is_better, - previous_choices, - minimal_diversity, - max_seconds_per_solution, - ) - except InfeasibleError: - warnings.warn(f"Found only {i_run} feasible solutions (requested {num_solutions})") - break - previous_choices.append(chosen_items) - solutions.append( - {"chosen_items": chosen_items, "total_value": total_value, "total_costs": total_costs} - ) - return solutions - - -def grouped_knapsack( - groups: dict[Hashable, Group], - objective: str, - constraints: dict[str, float | tuple[float, float]], - bigger_is_better: bool, - previous_choices: Optional[list[ChosenItems]] = None, - minimal_diversity: int = 1, - max_seconds_per_solution: Optional[float] = None, -) -> tuple[ChosenItems, float, dict[str, float]]: - groups = deepcopy(groups) - mip_model = Model() - - objective_vars = [] - constraint_vars = {constraint_key: [] for constraint_key in constraints.keys()} - for group_name, group_items in groups.items(): - group_vars = [] - for item_name, item in group_items.items(): - is_chosen = mip_model.add_var(var_type=BINARY) - item["is_chosen"] = is_chosen - group_vars.append(is_chosen) - objective_vars.append(is_chosen * get_nested_objective(item, objective)) - for constraint_key in constraints.keys(): - constraint_vars[constraint_key].append( - is_chosen * get_nested_key(item, constraint_key) - ) - - mip_model += xsum(group_vars) == 1 - - for constraint_key, max_cost in constraints.items(): - min_cost = None - if isinstance(max_cost, Iterable): - min_cost, max_cost = max_cost - - if max_cost is not None: - mip_model += xsum(constraint_vars[constraint_key]) <= max_cost - if min_cost is not None: - mip_model += xsum(constraint_vars[constraint_key]) >= min_cost - - if previous_choices is not None: - for previous_chosen_items in previous_choices: - corresponding_vars = [ - groups[group_name][item_name]["is_chosen"] - for group_name, item_name in previous_chosen_items.items() - ] - mip_model += xsum(corresponding_vars) <= len(groups) - minimal_diversity - - mip_model.objective = ( - maximize(xsum(objective_vars)) if bigger_is_better else minimize(xsum(objective_vars)) - ) - - if max_seconds_per_solution is not None: - mip_model.max_seconds = max_seconds_per_solution - - mip_model.optimize() - - if is_chosen.x is None: - raise InfeasibleError() - - total_value = 0.0 - total_costs = {constraint_key: 0 for constraint_key in constraints.keys()} - chosen_items: ChosenItems = dict() - for group_name, group_items in groups.items(): - for item_name, item in group_items.items(): - is_chosen = item["is_chosen"].x >= 0.99 - if is_chosen: - assert group_name not in chosen_items - chosen_items[group_name] = item_name - total_value += get_nested_objective(item, objective) - for constraint_key in constraints.keys(): - total_costs[constraint_key] += get_nested_key(item, constraint_key) - - if len(chosen_items) != len(groups): - in_groups_and_not_in_chosen_items = set(groups.keys()) - set(chosen_items.keys()) - in_chosen_items_and_not_in_groups = set(chosen_items.keys()) - set(groups.keys()) - missing_groups = [groups[key] for key in in_groups_and_not_in_chosen_items] - raise RuntimeError(f""" - Different number of 'chosen_items' and 'groups': {len(chosen_items)=} {len(groups)=} - {in_groups_and_not_in_chosen_items=} - {in_chosen_items_and_not_in_groups=} - {missing_groups=} - """) - - for constraint_key, max_cost in constraints.items(): - min_cost = None - if isinstance(max_cost, Iterable): - min_cost, max_cost = max_cost - - if max_cost is not None: - assert total_costs[constraint_key] < max_cost or math.isclose( - total_costs[constraint_key], max_cost, rel_tol=1e-9 - ), ( - f"This max_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} > {max_cost=}" - ) - if min_cost is not None: - assert total_costs[constraint_key] > min_cost or math.isclose( - total_costs[constraint_key], min_cost, rel_tol=1e-9 - ), ( - f"This min_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} < {min_cost=}" - ) - - for previous_chosen_items in previous_choices: - num_differences = 0 - for group_name in groups.keys(): - num_differences += previous_chosen_items[group_name] != chosen_items[group_name] - assert num_differences >= minimal_diversity - - return chosen_items, total_value, total_costs - - -def get_nested_objective(dictionary: dict[str, Any], nested_key: str) -> Any: - if nested_key.startswith("metrics."): - # handle metrics that have '.' in their name - metric = nested_key.split("metrics.")[1] - return dictionary["metrics"][metric] - else: - return get_nested_key(dictionary, nested_key) - - -def usage_example(): - num_layers = 32 - num_configs_per_block = 100 - groups = { - f"layer_{i_layer}": { - f"config_{i_config}": { - "metrics": {"accuracy": random()}, - "stats": {"memory_mib": random() * 100, "runtime_ms": random() * 10}, - } - for i_config in range(num_configs_per_block) - } - for i_layer in range(num_layers) - } - - minimal_diversity = 10 - constraints = {"stats.memory_mib": num_layers * 50.0, "stats.runtime_ms": num_layers * 5.0} - solutions = multi_solution_grouped_knapsack( - groups, - objective="metrics.accuracy", - constraints=constraints, - bigger_is_better=True, - num_solutions=10, - minimal_diversity=minimal_diversity, - ) - - print() - print(constraints) - - for i_run, solution in enumerate(solutions): - print() - print(f"run {i_run}") - print(solution) - - print(f"Checking differences, should be at least {minimal_diversity}:") - for a in range(len(solutions)): - for b in range(a + 1, len(solutions)): - num_differences = 0 - for group_name in groups.keys(): - num_differences += ( - solutions[a]["chosen_items"][group_name] - != solutions[b]["chosen_items"][group_name] - ) - print(a, "<>", b, "=", num_differences) - - -if __name__ == "__main__": - usage_example() diff --git a/modelopt/torch/_compress/mip/mip_and_realize_models.py b/modelopt/torch/_compress/mip/mip_and_realize_models.py index 83d8b23f56..f6d77d2624 100644 --- a/modelopt/torch/_compress/mip/mip_and_realize_models.py +++ b/modelopt/torch/_compress/mip/mip_and_realize_models.py @@ -44,12 +44,19 @@ def launch_realize_model(cfg: DictConfig, runtime: IRuntime): def launch_mip_and_realize_model(cfg: DictConfig, runtime: IRuntime): + # Determine device for distributed operations (NCCL requires CUDA tensors) + device = "cpu" + if runtime.world_size > 1 and dist.is_initialized(): + backend = dist.get_backend() + if backend == "nccl": + device = torch.cuda.current_device() + if runtime.is_main_process: solution_paths = launch_mip(cfg) - length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long) + length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long, device=device) else: solution_paths = None - length_tensor = torch.tensor([0], dtype=torch.long) + length_tensor = torch.tensor([0], dtype=torch.long, device=device) if not cfg.skip_realize_model: if runtime.world_size > 1: @@ -75,7 +82,7 @@ def main(cfg: DictConfig) -> None: cfg = hydra.utils.instantiate(cfg) _runtime = ( - NativeDDP_Runtime( + NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") ) if is_distributed() diff --git a/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py b/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py index 50525c846c..438db3312e 100644 --- a/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py +++ b/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py @@ -25,7 +25,12 @@ from mip import BINARY, Model, maximize, minimize, xsum -from .utils import InfeasibleError, consecutive_ngrams, get_nested_key, sort_replacements +from modelopt.torch._compress.mip.utils import ( + InfeasibleError, + consecutive_ngrams, + get_nested_key, + sort_replacements, +) ReplacementID: TypeAlias = Hashable Replacement: TypeAlias = dict[str, Any] diff --git a/modelopt/torch/_compress/mip/run_puzzle.py b/modelopt/torch/_compress/mip/run_puzzle.py index fd883e969f..5773349c11 100644 --- a/modelopt/torch/_compress/mip/run_puzzle.py +++ b/modelopt/torch/_compress/mip/run_puzzle.py @@ -28,15 +28,11 @@ import yaml from omegaconf import DictConfig, ListConfig, OmegaConf -import modelopt.torch._compress.mip.constrain_search_space as css from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, FFNConfig, ) -from modelopt.torch._compress.mip.greedy_search_with_multi_layer_replacements import ( - run_greedy_search, -) from modelopt.torch._compress.mip.mip_with_multi_layer_replacements import ( run_mip as run_multi_layer_replacement_mip, ) @@ -211,8 +207,6 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--human_constraints", type=parse_json) parser.add_argument("--report_additional_costs", type=str, action="append", default=[]) - parser.add_argument("--num_solutions", type=int) - parser.add_argument("--minimal_diversity", type=int) parser.add_argument( "--output_path", type=parse_path, @@ -227,21 +221,6 @@ def parse_args() -> argparse.Namespace: help="Set this if using accuracy objective, don't set if using loss objective", ) - parser.add_argument("--constrain_search_func", type=str, default=None) - parser.add_argument("--constrain_search_args", type=parse_json, default=dict()) - - parser.add_argument( - "--is_multi_layer_puzzle", - action="store_true", - default=True, - help="[DEPRECATED] This flag is now always True. Kept for backward compatibility.", - ) - parser.add_argument( - "--use_greedy_search", - action="store_true", - help="Use greedy search instead of mip. Only supported for multi-layer puzzle.", - ) - args = parser.parse_args() return args @@ -254,17 +233,14 @@ def run_single_puzzle_config( constraints: PuzzleConstraints, output_folder, ) -> None: - from modelopt.torch._compress.mip.grouped_knapsack import multi_solution_grouped_knapsack - - args = deepcopy( - args - ) # we override the constraints and subblock_stats_args for this run to keep reporting out the same way. + # we override the constraints and subblock_stats_args for this run to keep reporting out the same way. + args = deepcopy(args) subblock_stats = filter_subblock_stats_by_args(subblock_stats, subblock_stats_args) _add_block_stats_to_gathered_metrics(gathered_metrics, subblock_stats) output_folder.mkdir(parents=True, exist_ok=True) - _dump_gathered_metrics(gathered_metrics, output_folder, args.is_multi_layer_puzzle) + _dump_gathered_metrics(gathered_metrics, output_folder) non_block_stats = {"stats": _get_block_stats(subblock_stats, "non_block")} batch_size = subblock_stats["args"]["batch_size"] @@ -304,40 +280,13 @@ def run_single_puzzle_config( mprint(f"After non-block adjustments: {mip_constraints=}") - if args.is_multi_layer_puzzle: - if not args.use_greedy_search: - solutions = run_multi_layer_replacement_mip( - replacements=gathered_metrics, - objective=args.objective, - constraints=mip_constraints, - bigger_is_better=args.bigger_is_better, - max_seconds_per_solution=args.max_seconds_per_solution, - ) - else: - teacher_replacements, student_replacements = [], [] - for replacement in gathered_metrics.values(): - if replacement["is_teacher"]: - teacher_replacements.append(replacement) - else: - student_replacements.append(replacement) - - solutions = run_greedy_search( - teacher_replacements=teacher_replacements, - student_replacements=student_replacements, - objective=args.objective, - constraints=mip_constraints, - bigger_is_better=args.bigger_is_better, - ) - else: - solutions = multi_solution_grouped_knapsack( - groups=gathered_metrics, - objective=args.objective, - constraints=mip_constraints, - bigger_is_better=args.bigger_is_better, - num_solutions=args.num_solutions, - minimal_diversity=args.minimal_diversity, - max_seconds_per_solution=args.max_seconds_per_solution, - ) + solutions = run_multi_layer_replacement_mip( + replacements=gathered_metrics, + objective=args.objective, + constraints=mip_constraints, + bigger_is_better=args.bigger_is_better, + max_seconds_per_solution=args.max_seconds_per_solution, + ) for solution in solutions: for stat_name in set([*orig_mip_constraints.keys(), *args.report_additional_costs]): @@ -379,25 +328,10 @@ def run_single_puzzle_config( return solutions_file -def _dump_gathered_metrics( - gathered_metrics: PuzzleMetrics, output_folder: Path, is_multi_layer_puzzle: bool = False -) -> None: - if is_multi_layer_puzzle: - for replacement_id, replacement_info in gathered_metrics.items(): - replacement_info["block_repr"] = block_config_to_str(replacement_info["block_config"]) - gathered_metrics_for_dump = gathered_metrics - else: - gathered_metrics_for_dump = { - block_name: { - block_config_to_str(variant_config).strip(): { - **variant_metrics, - "block_config": variant_config, - "block_repr": block_config_to_str(variant_config).strip(), - } - for variant_config, variant_metrics in block_variants.items() - } - for block_name, block_variants in gathered_metrics.items() - } +def _dump_gathered_metrics(gathered_metrics: PuzzleMetrics, output_folder: Path) -> None: + for replacement_id, replacement_info in gathered_metrics.items(): + replacement_info["block_repr"] = block_config_to_str(replacement_info["block_config"]) + gathered_metrics_for_dump = gathered_metrics json_dump(gathered_metrics_for_dump, output_folder / "replacement_metrics_and_stats.json") @@ -451,17 +385,12 @@ def _override_args_from_profile(args, puzzle_profile): if arg_name in puzzle_profile: if arg_name not in ("mip_constraints", "human_constraints", "subblock_stats_args"): setattr(args, arg_name, puzzle_profile[arg_name]) - if isinstance(args.constrain_search_args, str): - args.constrain_search_args = parse_json(args.constrain_search_args) - assert args.is_multi_layer_puzzle, "multi-layer puzzle is now the only supported mode." def _assert_valid_config(args, puzzle_profile): required_args = ( "subblock_stats_path", "objective", - "num_solutions", - "minimal_diversity", "output_path", ) missing_args = [arg for arg in required_args if arg not in args or getattr(args, arg) is None] @@ -488,11 +417,6 @@ def _assert_valid_config(args, puzzle_profile): ) exit(1) - if args.use_greedy_search: - assert args.is_multi_layer_puzzle, ( - "--use_greedy_search is only supported for multi layer puzzle" - ) - def _get_minimal_unique_names(dicts: List[dict]) -> List[str]: all_keys = set(k for d in dicts for k in d.keys()) @@ -517,23 +441,13 @@ def run_puzzle(args: argparse.Namespace) -> List[str]: if args.gathered_metrics_path is not None: gathered_metrics = json.loads(args.gathered_metrics_path.read_text()) else: - gather_func = ( - gather_puzzle_metrics - if not args.is_multi_layer_puzzle - else gather_multi_layer_puzle_metrics + gathered_metrics = gather_multi_layer_puzle_metrics( + args.single_block_replacement_validation_dir ) - gathered_metrics = gather_func(args.single_block_replacement_validation_dir) if args.metric_overrides is not None: gathered_metrics = {**gathered_metrics, **args.metric_overrides} - if args.constrain_search_func is not None: - mprint(f"{args.constrain_search_args=}") - # assert not args.is_multi_layer_puzzle, "conditional search is not implementd yet for multi-layer puzzles, did you implement it?" - gathered_metrics = css.apply( - args.constrain_search_func, gathered_metrics, args.constrain_search_args - ) - subblock_stats = json.loads(args.subblock_stats_path.read_text()) all_subblock_args = _load_all_subblock_stats_args(args, puzzle_profile) diff --git a/modelopt/torch/_compress/sewing_kit/utils.py b/modelopt/torch/_compress/sewing_kit/utils.py index 16fe1b3fd3..ff47c289b6 100644 --- a/modelopt/torch/_compress/sewing_kit/utils.py +++ b/modelopt/torch/_compress/sewing_kit/utils.py @@ -447,13 +447,33 @@ def get_parent_module_names(module_name: str): return parent_module_names +def _get_device_for_distributed( + group: Optional[torch.distributed.ProcessGroup] = None, +) -> str: + """ + Determine the appropriate device for distributed communication based on the backend. + NCCL backend requires CUDA tensors, while Gloo supports both CPU and CUDA. + """ + if not torch.distributed.is_initialized(): + return "cpu" + + backend = torch.distributed.get_backend(group) + if backend == "nccl": + # NCCL requires CUDA tensors + return torch.cuda.current_device() + else: + # Gloo and other backends support CPU tensors + return "cpu" + + def distributed_isend_obj( obj: Any, dst: int = 0, group: Optional[torch.distributed.ProcessGroup] = None, ) -> list[Optional[torch.distributed.Work]]: + device = _get_device_for_distributed(group) obj_tensor, obj_size_tensor = torch.distributed.distributed_c10d._object_to_tensor( - obj, device="cpu", **_get_group_kwarg_if_necessary() + obj, device=device, **_get_group_kwarg_if_necessary() ) works: list[Optional[torch.distributed.Work]] = [ torch.distributed.isend(obj_size_tensor, dst, group), @@ -484,11 +504,12 @@ def distributed_recv_obj( src: Optional[int] = None, group: Optional[torch.distributed.ProcessGroup] = None, ) -> Any: - obj_size_tensor = torch.LongTensor(1, device="cpu") + device = _get_device_for_distributed(group) + obj_size_tensor = torch.LongTensor(1).to(device) torch.distributed.recv(obj_size_tensor, src=src, group=group) obj_size = int(obj_size_tensor.item()) - obj_tensor = torch.ByteTensor(obj_size, device="cpu") + obj_tensor = torch.ByteTensor(obj_size).to(device) torch.distributed.recv(obj_tensor, src=src, group=group) obj = torch.distributed.distributed_c10d._tensor_to_object( diff --git a/setup.py b/setup.py index d4077f709d..20a271fe15 100644 --- a/setup.py +++ b/setup.py @@ -105,13 +105,13 @@ "compress": [ "fire", "hydra-core==1.3.2", - "omegaconf==2.3.0", - "wandb~=0.17.5", - "lru-dict", - "typeguard", - "pandas", "immutabledict", + "lru-dict", "mip", + "omegaconf==2.3.0", + "pandas", + "typeguard", + "wandb~=0.17.5", ], } diff --git a/tests/gpu/torch/_compress/compress_test_utils.py b/tests/gpu/torch/_compress/compress_test_utils.py index a1102e7fac..9df5f5bfcf 100644 --- a/tests/gpu/torch/_compress/compress_test_utils.py +++ b/tests/gpu/torch/_compress/compress_test_utils.py @@ -29,11 +29,7 @@ def setup_test_model_and_data( tmp_path: Path, rank: int, runtime, -) -> tuple[ - Path, - Path, - Path, -]: +) -> tuple[Path, Path, Path]: """ Setup the test model and data for the compress NAS search. @@ -132,7 +128,7 @@ def setup_puzzle_dir(puzzle_dir: str): Path(puzzle_dir).mkdir(parents=True, exist_ok=True) -def save_dummy_dataset(dataset_path: str): +def save_dummy_dataset(dataset_path: Path | str): """ Save a dummy dataset for testing purposes. """ @@ -170,4 +166,4 @@ def save_dummy_dataset(dataset_path: str): # For train-val splits data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) - data_dict.save_to_disk(dataset_path) + data_dict.save_to_disk(str(dataset_path)) diff --git a/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml b/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml index 21a3486f09..473a5d418d 100644 --- a/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml +++ b/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml @@ -9,7 +9,7 @@ defaults: puzzle_dir: ??? teacher_dir: ${puzzle_dir}/ckpts/teacher/ replacement_library_path: ${puzzle_dir}/replacement_library.json -dataset_path: ??? # path to v0.4_mini +dataset_path: ??? # path to v0.4_mini skip_realize_model: false @@ -21,10 +21,10 @@ calc_subblock_stats: batch_sizes: [64, 96, 128] prefill_seq_len: 4096 generation_seq_len: 4096 - num_active_tokens_override: # Optional override for sequence lengths + num_active_tokens_override: # Optional override for sequence lengths prefill_queue_size: 0 allocate_prefill_query: false - benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking merge_with_existing_stats: false subblock_stats_filename: "subblock_stats.json" moe_stats_filename: "moe_stats.json" @@ -54,8 +54,6 @@ mip: # puzzle_profile: objective: metrics.cosine_embedding_loss_hidden_states bigger_is_better: false - num_solutions: 1 - minimal_diversity: 2 subblock_stats_args: - batch_size: 96 @@ -79,10 +77,7 @@ mip: target_memory: 780_000 # 78_000 mip_constraints: - use_greedy_search: false - is_multi_layer_puzzle: true metric_overrides: - constrain_search_func: max_seconds_per_solution: 60 realize_model: @@ -90,10 +85,10 @@ realize_model: tokenizer_name: ${to_path:${teacher_dir}} replacement_library_path: ${replacement_library_path} save_models: true - solutions_path: # Filled dynamically + solutions_path: # Filled dynamically # Validate params - skip_validation: false # To enable validation of the model solution set `skip_validation` as False + skip_validation: false # To enable validation of the model solution set `skip_validation` as False eval_samples: 2 micro_batch_size: 1 dataset_path: ${dataset_path}/valid diff --git a/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml b/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml index 1d8fac655f..8af352660b 100644 --- a/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml +++ b/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml @@ -9,7 +9,7 @@ defaults: puzzle_dir: ??? teacher_dir: ${puzzle_dir}/ckpts/teacher/ replacement_library_path: ${puzzle_dir}/replacement_library.json -dataset_path: ??? # path to v0.4_mini +dataset_path: ??? # path to v0.4_mini skip_realize_model: false @@ -21,10 +21,10 @@ calc_subblock_stats: batch_sizes: [64, 96, 128] prefill_seq_len: 4096 generation_seq_len: 4096 - num_active_tokens_override: # Optional override for sequence lengths + num_active_tokens_override: # Optional override for sequence lengths prefill_queue_size: 0 allocate_prefill_query: false - benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking merge_with_existing_stats: false subblock_stats_filename: "subblock_stats.json" moe_stats_filename: "moe_stats.json" @@ -54,8 +54,6 @@ mip: # puzzle_profile: objective: metrics.cosine_embedding_loss_hidden_states bigger_is_better: false - num_solutions: 1 - minimal_diversity: 2 subblock_stats_args: - batch_size: 96 @@ -79,10 +77,7 @@ mip: target_memory: 780_000 # 78_000 mip_constraints: - use_greedy_search: false - is_multi_layer_puzzle: true metric_overrides: - constrain_search_func: max_seconds_per_solution: 60 realize_model: @@ -90,10 +85,10 @@ realize_model: tokenizer_name: ${to_path:${teacher_dir}} replacement_library_path: ${replacement_library_path} save_models: true - solutions_path: # Filled dynamically + solutions_path: # Filled dynamically # Validate params - skip_validation: false # To enable validation of the model solution set `skip_validation` as False + skip_validation: false # To enable validation of the model solution set `skip_validation` as False eval_samples: 2 micro_batch_size: 1 dataset_path: ${dataset_path}/valid diff --git a/tests/gpu/torch/_compress/test_compress.py b/tests/gpu/torch/_compress/test_compress.py index b00be24857..e40756602a 100644 --- a/tests/gpu/torch/_compress/test_compress.py +++ b/tests/gpu/torch/_compress/test_compress.py @@ -33,20 +33,6 @@ # # Note: Bypass is disabled now in the test. -# How to run this test (currently only supported internally at Nvidia). -# -# Have both modelopt and puzzle source code in the same directory: -# /workspace/modelopt -# /workspace/puzzletron -# -# submit_job --partition interactive --time 0 \ -# --image gitlab-master.nvidia.com/deci/puzzletron:modelopt_main \ -# --workdir $MODELOPT SRC DIRECTORY --interactive --gpu 1 -# -# export PYTHONPATH=$PYTHONPATH:.:/workspace/puzzletron/v1 -# -# pytest -s -v ./tests/gpu/torch/_compress/test_compress.py::test_compress -o addopts="" - def test_compress(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( From 67489f423ef7de8275ab64d8d182e7c86ebb8a18 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 11 Dec 2025 16:12:54 +0100 Subject: [PATCH 24/62] Fix a bug in IterativeChannelContributionHook + tools for activation hooks analysis (#670) ## What does this PR do? Fix a bug in IterativeChannelContributionHook + tools for activation hooks analysis --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Daniel Korzekwa Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../nas/plugins/megatron_hooks/__init__.py | 18 ++ .../megatron_hooks/compare_module_outputs.py | 291 ++++++++++++++++++ .../{ => megatron_hooks}/megatron_hooks.py | 117 ++++++- .../megatron_hooks/megatron_hooks_analysis.py | 104 +++++++ .../test_megatron_hooks.py | 0 .../test_megatron_hooks_analysis.py | 217 +++++++++++++ 6 files changed, 743 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/__init__.py create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py rename modelopt/torch/nas/plugins/{ => megatron_hooks}/megatron_hooks.py (76%) create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks_analysis.py rename tests/gpu/torch/nas/plugins/{ => megatron_hooks}/test_megatron_hooks.py (100%) create mode 100644 tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks_analysis.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py new file mode 100644 index 0000000000..1d19308edf --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Forward hooks for estimating importance scores for pruning.""" + +from .megatron_hooks import * +from .megatron_hooks_analysis import * diff --git a/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py b/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py new file mode 100644 index 0000000000..316aff76ff --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Compare module output tensors from different model variants. + +This module provides: +1. OutputSaveHook - A PyTorch hook to capture module outputs during forward pass +2. Comparison utilities - Compute RMSE and cosine similarity between saved outputs + +Usage Example: +-------------- + +Step 1: Capture outputs from multiple layers: + + from modelopt.torch.nas.plugins.megatron_hooks.compare_module_outputs import ( + OutputSaveHook, + save_multi_layer_outputs, + ) + + # Register hooks on all target layers + hooks = {} + for name, module in model.named_modules(): + if name.endswith('mlp.linear_fc2'): + hook = OutputSaveHook(layer_name=name) + module.register_forward_hook(hook) + hooks[name] = hook + + # Run inference/training + model(input_data) + + # Save all layer outputs + save_multi_layer_outputs(hooks, "output_unpruned.pt") + +Step 2: Compare outputs from different model variants: + + python compare_module_outputs.py \ + --reference output_unpruned.pt \ + --compare output_l2norm.pt \ + --output-json comparison_stats.json + +The saved file format: +{ + 'decoder.layers.0.mlp.linear_fc2': Tensor([steps, seq_len, batch, hidden]), + 'decoder.layers.1.mlp.linear_fc2': Tensor([...]), + ... + 'metadata': {'num_layers': N, 'num_steps': M, 'layer_names': [...]} +} +""" + +import argparse + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class OutputSaveHook: + """Hook to capture and save module outputs during forward pass.""" + + def __init__(self, layer_name: str) -> None: + """Initialize the output save hook. + + Args: + layer_name: Hierarchical name of the layer (e.g., 'decoder.layers.0.mlp.linear_fc2'). + """ + self.layer_name = layer_name + self.saved_outputs: list[torch.Tensor] = [] + + def __call__( + self, + module: nn.Module, + args: tuple[torch.Tensor, ...], + output: torch.Tensor | tuple[torch.Tensor, ...], + ) -> None: + """Capture and save module output during forward pass. + + Args: + module: The PyTorch module being hooked. + args: Input arguments to the module's forward pass. + output: Output tensor(s) from the module's forward pass. + """ + # Handle tuple outputs (e.g., output, bias) + out = output[0] if isinstance(output, tuple) else output + self.saved_outputs.append(out.detach().cpu()) + + def get_outputs_list(self) -> list[torch.Tensor]: + """Return saved outputs as a list.""" + return self.saved_outputs + + +def save_multi_layer_outputs(hooks: dict[str, OutputSaveHook], path: str) -> None: + """Save outputs from multiple layers to a single file. + + Args: + hooks: Dictionary mapping layer names to their hooks. + path: Path to save the outputs. + """ + output_dict = {name: hook.get_outputs_list() for name, hook in hooks.items()} + + # Add metadata + output_dict["metadata"] = { + "num_layers": len(hooks), + # Number of forward passes (generation steps) - all hooks have same count, so use first hook + "num_steps": len(next(iter(hooks.values())).saved_outputs) if hooks else 0, + "layer_names": list(hooks.keys()), + } + + torch.save(output_dict, path) + print(f"\nSaved outputs from {len(hooks)} layers to {path}") + for name, data in output_dict.items(): + if name != "metadata": + print(f" {name}: list of {len(data)} tensors") + + +def compute_rmse(tensor1: torch.Tensor, tensor2: torch.Tensor) -> float: + """Compute Root Mean Square Error between two tensors.""" + mse = torch.mean((tensor1 - tensor2) ** 2) + rmse = torch.sqrt(mse) + return rmse.item() + + +def compute_cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor) -> dict: + """Compute average cosine similarity between two tensors.""" + # Flatten to 2D for cosine similarity computation + t1_flat = tensor1.reshape(-1, tensor1.shape[-1]) + t2_flat = tensor2.reshape(-1, tensor2.shape[-1]) + + # Compute cosine similarity per position + cos_sim = F.cosine_similarity(t1_flat, t2_flat, dim=-1) + + return { + "mean": cos_sim.mean().item(), + "min": cos_sim.min().item(), + "max": cos_sim.max().item(), + "std": cos_sim.std().item(), + } + + +def main(): + """Compare module output tensors from different model variants.""" + parser = argparse.ArgumentParser( + description="Compare module output tensors from different model variants", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--reference", + type=str, + required=True, + help="Path to reference output tensor (e.g., unpruned model)", + ) + parser.add_argument( + "--compare", + type=str, + required=True, + help="Path to output tensor to compare against reference", + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save comparison statistics as JSON", + ) + args = parser.parse_args() + + # Load reference data + print(f"\nLoading reference: {args.reference}") + ref_data = torch.load(args.reference, map_location="cpu") + + # Load comparison data + print(f"Loading compare: {args.compare}") + comp_data = torch.load(args.compare, map_location="cpu") + + # Compare multi-layer outputs + compare_multi_layer(ref_data, comp_data, args.output_json) + + +def compute_layer_metrics(ref_data: list, comp_data: list) -> dict: + """Compute RMSE and cosine similarity for a layer's outputs. + + Args: + ref_data: List of reference tensors. + comp_data: List of comparison tensors. + + Returns: + Dictionary with metrics. + + Raises: + ValueError: If lengths don't match or tensor shapes don't match. + """ + if len(ref_data) != len(comp_data): + raise ValueError( + f"Length mismatch: reference has {len(ref_data)} samples, compare has {len(comp_data)}" + ) + + rmse_values = [] + cos_sim_values = [] + + for ref_tensor, comp_tensor in zip(ref_data, comp_data): + if ref_tensor.shape != comp_tensor.shape: + raise ValueError( + f"Shape mismatch at index {len(rmse_values)}: " + f"reference {ref_tensor.shape} vs compare {comp_tensor.shape}" + ) + rmse_values.append(compute_rmse(ref_tensor, comp_tensor)) + cos_sim = compute_cosine_similarity(ref_tensor, comp_tensor) + cos_sim_values.append(cos_sim["mean"]) + + return { + "rmse": sum(rmse_values) / len(rmse_values), + "cosine_sim": { + "mean": sum(cos_sim_values) / len(cos_sim_values), + "min": min(cos_sim_values), + "max": max(cos_sim_values), + "std": torch.tensor(cos_sim_values).std().item() if len(cos_sim_values) > 1 else 0.0, + }, + "num_samples": len(rmse_values), + } + + +def compare_multi_layer(ref_data: dict, comp_data: dict, output_json: str | None = None): + """Compare multi-layer outputs.""" + import json + + ref_layers = [k for k in ref_data if k != "metadata"] + comp_layers = [k for k in comp_data if k != "metadata"] + + if set(ref_layers) != set(comp_layers): + print("\nERROR: Layer mismatch!") + print(f"Reference layers: {ref_layers}") + print(f"Compare layers: {comp_layers}") + return + + results = {"aggregated": {"rmse": [], "cosine_sim_mean": []}, "per_layer": {}} + + # Per-layer comparison + for layer_name in sorted(ref_layers): + ref_layer_data = ref_data[layer_name] + comp_layer_data = comp_data[layer_name] + + metrics = compute_layer_metrics(ref_layer_data, comp_layer_data) + + results["per_layer"][layer_name] = metrics + results["aggregated"]["rmse"].append(metrics["rmse"]) + results["aggregated"]["cosine_sim_mean"].append(metrics["cosine_sim"]["mean"]) + + # Aggregated statistics + if results["aggregated"]["rmse"]: + rmse_array = torch.tensor(results["aggregated"]["rmse"]) + cos_sim_array = torch.tensor(results["aggregated"]["cosine_sim_mean"]) + + results["aggregated"]["rmse_stats"] = { + "mean": rmse_array.mean().item(), + "std": rmse_array.std().item(), + "min": rmse_array.min().item(), + "max": rmse_array.max().item(), + } + results["aggregated"]["cosine_sim_stats"] = { + "mean": cos_sim_array.mean().item(), + "std": cos_sim_array.std().item(), + "min": cos_sim_array.min().item(), + "max": cos_sim_array.max().item(), + } + results["aggregated"]["num_steps"] = ref_data.get("metadata", {}).get("num_steps", None) + results["aggregated"]["num_layers"] = len(rmse_array) + + # Save to JSON if requested + if output_json: + # Remove raw lists for JSON serialization + results["aggregated"].pop("rmse", None) + results["aggregated"].pop("cosine_sim_mean", None) + + with open(output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"Saved comparison results to {output_json}") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/nas/plugins/megatron_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py similarity index 76% rename from modelopt/torch/nas/plugins/megatron_hooks.py rename to modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py index 12e07c59ea..3bb1493950 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py @@ -23,6 +23,12 @@ from megatron.core.tensor_parallel.layers import RowParallelLinear from torch import nn +__all__ = [ + "IndependentChannelContributionHook", + "IterativeChannelContributionHook", + "MegatronL2NormHook", +] + def clear_gpu_memory(clear: bool) -> None: """Clear GPU memory cache if requested. @@ -167,6 +173,108 @@ def load_state_dict(self, state_dict: dict) -> None: self._activations = state_dict["activations"] +class IndependentChannelContributionHook(ForwardHook): + """Hook for channel importance estimation using weight norms and activation magnitudes. + + Computes channel importance as the product of: + - L2 norm of each column in the weight matrix (how much each input channel affects output) + - Mean absolute activation for each channel (how strongly each channel is activated) + + Args: + linear_layer: The linear projection layer to analyze. Can be either nn.Linear or + RowParallelLinear from megatron.core.tensor_parallel.layers. + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__( + self, + linear_layer: nn.Linear | RowParallelLinear, + max_size: int | None = None, + ): + """Initialize the independent channel contribution hook.""" + self.max_size = max_size + + weight_matrix = linear_layer.weight.float() + self.weight_norm = torch.linalg.vector_norm(weight_matrix, dim=0) + + # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) + if hasattr(linear_layer, "input_size"): + self.num_channels = linear_layer.input_size # Megatron-Core + else: + self.num_channels = linear_layer.in_features # PyTorch + + self.agg_channel_activations = torch.zeros( + size=(self.num_channels,), + dtype=torch.float32, + device=weight_matrix.device, + ) + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple + ) -> None: + """Accumulate mean absolute activations per channel. + + Args: + module: The module this hook is registered on. + args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] + (PyTorch batch-first format). + output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) + for parallel layers. + """ + activations = args[0] + + # Don't aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and activations.shape[-1] != self.max_size: + return + + mean_abs_channel_activations = ( + activations.abs().float().mean(dim=list(range(activations.ndim - 1))) + ) + self.agg_channel_activations[:] += mean_abs_channel_activations # shape [input_channels] + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert results to dict with channel importance scores. + + Returns: + Dict with "score" (weight_norm * activations), "weight_norm", and + "agg_channel_activations". + """ + return { + "score": (self.weight_norm * self.agg_channel_activations).cpu(), + "weight_norm": self.weight_norm.cpu(), + "agg_channel_activations": self.agg_channel_activations.cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return importance scores as a tensor. + + Returns: + Tensor of importance scores (weight_norm * activations), one per channel. + """ + return self.to_dict()["score"] + + def state_dict(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "agg_channel_activations": self.agg_channel_activations.cpu().clone(), + "weight_norm": self.weight_norm.cpu().clone(), + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.agg_channel_activations = state_dict["agg_channel_activations"].to( + self.agg_channel_activations.device + ) + # weight_norm should be the same as it's derived from the model weights + # but we can verify it matches + expected_weight_norm = state_dict["weight_norm"].to(self.weight_norm.device) + if not torch.allclose(self.weight_norm, expected_weight_norm, rtol=1e-5): + raise AssertionError( + "weight_norm mismatch during state loading - model weights may have changed" + ) + + def get_pruning_schedule(num_channels, pruning_iters): """Spending decreases monotonically when num_channels >= pruning_iters. @@ -309,10 +417,11 @@ def __call__( del contribution, contribution_squared clear_gpu_memory(clear=self.clear_gpu_memory) - if n_channels_to_prune == 0: - self.agg_cont_per_channel += mean_cont_per_channel - else: - _, worst_indices = torch.topk(mean_cont_per_channel, n_channels_to_prune, largest=False) + self.agg_cont_per_channel += mean_cont_per_channel + if n_channels_to_prune > 0: + _, worst_indices = torch.topk( + self.agg_cont_per_channel, n_channels_to_prune, largest=False + ) worst_indices_list = worst_indices.tolist() assert not set(self.pruned_channels).intersection(set(worst_indices_list)) self.pruned_channels.extend(worst_indices_list) diff --git a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks_analysis.py b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks_analysis.py new file mode 100644 index 0000000000..caf5eed898 --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks_analysis.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Analysis tools for evaluating importance scores from megatron hooks.""" + +import torch +import torch.nn.functional as F +from torch import nn + +__all__ = ["evaluate_importance_scores"] + + +def evaluate_importance_scores( + linear_layer: nn.Linear, + activations_batches: list[torch.Tensor], + importance_scores: torch.Tensor, + prune_ratio: float = 0.2, +) -> dict[str, float]: + """Compute reconstruction error after pruning input channels of a linear layer. + + This function simulates channel pruning by zeroing out input channels identified as + least important, then measures how much the layer's output changes. + + Args: + linear_layer: The linear layer to analyze with shape (out_features, in_features). + For example: nn.Linear(in_features=1024, out_features=4096) + activations_batches: List of input activation tensors. + Each tensor has shape [seq_len, batch_size, in_features] (Megatron format). + The last dimension must match linear_layer.in_features. + Example: List of [16, 8, 1024] tensors + importance_scores: Importance score for each input channel (feature). + Shape: [in_features]. Lower scores = less important. + Example: [1024] tensor with one score per input feature + prune_ratio: Fraction of input channels to prune (default: 0.2 means prune 20%). + + Returns: + Dictionary containing averaged metrics across all activation batches: + - rmse: Root mean squared error between original and pruned output + - cosine_similarity: Cosine similarity between original and pruned output + - num_pruned: Number of input channels pruned + + Example: + >>> layer = nn.Linear(in_features=1024, out_features=4096) + >>> # Collect multiple batches for robust evaluation + >>> activations_list = [torch.randn(16, 8, 1024) for _ in range(100)] + >>> scores = torch.randn(1024) # one score per input feature + >>> metrics = evaluate_importance_scores(layer, activations_list, scores, 0.2) + >>> print(f"RMSE: {metrics['rmse']:.4f}, Pruned: {metrics['num_pruned']} channels") + + Note: + - This simulates pruning (zeros out inputs) without modifying layer weights + - "Channels" refers to INPUT features, not output features + + """ + num_channels = importance_scores.shape[0] + num_to_prune = int(num_channels * prune_ratio) + + # Identify channels to prune (lowest scoring = least important) + _, channels_to_prune = torch.topk(importance_scores, num_to_prune, largest=False) + + # Compute metrics for each batch and average + rmse_values = [] + cosine_values = [] + + for activations in activations_batches: + # Get original output + original_output = linear_layer(activations) + + # Prune by zeroing out identified channels + pruned_activations = activations.clone() + pruned_activations[..., channels_to_prune] = 0 + + # Get pruned output + pruned_output = linear_layer(pruned_activations) + + # Compute metrics for this batch + rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item() + rmse_values.append(rmse) + + # Cosine similarity (flatten to vectors) + original_flat = original_output.reshape(-1) + pruned_flat = pruned_output.reshape(-1) + cosine = F.cosine_similarity( + original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1 + ).item() + cosine_values.append(cosine) + + # Return averaged metrics + return { + "rmse": sum(rmse_values) / len(rmse_values), + "cosine_similarity": sum(cosine_values) / len(cosine_values), + "num_pruned": num_to_prune, + } diff --git a/tests/gpu/torch/nas/plugins/test_megatron_hooks.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks.py similarity index 100% rename from tests/gpu/torch/nas/plugins/test_megatron_hooks.py rename to tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks.py diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks_analysis.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks_analysis.py new file mode 100644 index 0000000000..4f075c9dd2 --- /dev/null +++ b/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks_analysis.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for megatron hooks analysis tools.""" + +import pytest +import torch +import torch.nn as nn +from _test_utils.import_helper import skip_if_no_megatron + +skip_if_no_megatron() + +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from megatron.core.parallel_state import initialize_model_parallel + +from modelopt.torch.nas.plugins.megatron_hooks import ( + IndependentChannelContributionHook, + IterativeChannelContributionHook, + MegatronL2NormHook, + evaluate_importance_scores, +) + + +def test_evaluate_importance_scores_basic(): + """Test basic functionality of importance score evaluation with synthetic scores.""" + torch.manual_seed(42) + + # Create a simple linear layer (same dimensions as other tests for comparability) + layer = nn.Linear(in_features=50, out_features=30, bias=False) + + # Create synthetic hook that generates sequential importance scores + hook = SyntheticImportanceHook(num_features=50) + + # Use shared helper to run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[SyntheticImportanceHook] Metrics: {metrics}") + + # Check values with deterministic seed + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.3689444, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.77117118, rel=1e-5) + + +def _test_evaluate_importance_scores_with_l2_norm_hook(rank, size): + """Test evaluate_importance_scores with MegatronL2NormHook.""" + # Initialize Megatron parallel state + initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + torch.manual_seed(42) + + # Create layer and hook + layer = nn.Linear(in_features=50, out_features=30, bias=False) + hook = MegatronL2NormHook(max_size=None) + + # Run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[L2NormHook] Metrics: {metrics}") + + # L2NormHook specific assertions + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.3616334, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.7814186, rel=1e-5) + + +def _test_evaluate_importance_scores_with_iterative_channel_contribution_hook(rank, size): + """Test evaluate_importance_scores with IterativeChannelContributionHook.""" + # Initialize Megatron parallel state + initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + torch.manual_seed(42) + + # Create layer and hook + layer = nn.Linear(in_features=50, out_features=30, bias=False) + activation_hooks_kwargs = { + "validation_full_iters": 1000, + "clear_gpu_memory": False, + "calibration_method": None, + } + hook = IterativeChannelContributionHook(layer, activation_hooks_kwargs) + + # Run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[IterativeChannelContributionHook] Metrics: {metrics}") + + # Iterative channel contribution hook specific assertions + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.339014, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.8110392, rel=1e-5) + + +def _test_evaluate_importance_scores_with_independent_channel_contribution_hook(rank, size): + """Test evaluate_importance_scores with IndependentChannelContributionHook.""" + # Initialize Megatron parallel state + initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + torch.manual_seed(42) + + # Create layer and hook + layer = nn.Linear(in_features=50, out_features=30, bias=False) + hook = IndependentChannelContributionHook(layer) + + # Run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[IndependentChannelContributionHook] Metrics: {metrics}") + + # Independent channel contribution hook specific assertions + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.3385471, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.8116209, rel=1e-5) + + +def test_evaluate_importance_scores_with_l2_norm_hook(): + """Test evaluate_importance_scores using MegatronL2NormHook.""" + spawn_multiprocess_job( + size=1, + job=_test_evaluate_importance_scores_with_l2_norm_hook, + backend="gloo", + ) + + +def test_evaluate_importance_scores_with_iterative_channel_contribution_hook(): + """Test evaluate_importance_scores using IterativeChannelContributionHook.""" + spawn_multiprocess_job( + size=1, + job=_test_evaluate_importance_scores_with_iterative_channel_contribution_hook, + backend="gloo", + ) + + +def test_evaluate_importance_scores_with_independent_channel_contribution_hook(): + """Test evaluate_importance_scores using IndependentChannelContributionHook.""" + spawn_multiprocess_job( + size=1, + job=_test_evaluate_importance_scores_with_independent_channel_contribution_hook, + backend="gloo", + ) + + +def _run_hook_and_evaluate( + layer: nn.Linear, + hook, + num_iterations: int, + prune_ratio: float, +) -> dict: + """Shared helper to run hook, collect scores, and evaluate. + + Args: + layer: Linear layer to test + hook: Hook instance (already created) + num_iterations: Number of forward passes + prune_ratio: Fraction of channels to prune + + Returns: + Dictionary with evaluation metrics + """ + handle = layer.register_forward_hook(hook) # Store the handle + + # Run forward passes + all_activations = [] + for _ in range(num_iterations): + activations = torch.randn( + 16, 8, layer.in_features + ) # seq=16, batch=8, in_features=50 (Megatron format) + all_activations.append(activations) + _ = layer(activations) + + # Get importance scores from hook + importance_scores = hook.accumulate() + + # Remove the hook before evaluation to avoid triggering it again + handle.remove() + + # Evaluate the importance scores by simulating pruning on all collected activations + # Pass the list of activations to compute averaged metrics across batches + metrics = evaluate_importance_scores( + layer, + all_activations, # List of activation batches + importance_scores, + prune_ratio=prune_ratio, + ) + + return metrics + + +class SyntheticImportanceHook: + """Synthetic hook that generates sequential importance scores for testing. + + This is a simple mock hook that doesn't compute real importance, + just returns torch.arange(num_features) to test the evaluation pipeline. + """ + + def __init__(self, num_features: int): + """Initialize with the number of features.""" + self.num_features = num_features + + def __call__(self, module, args, output): + """Hook callback - does nothing for synthetic hook.""" + + def accumulate(self) -> torch.Tensor: + """Return synthetic importance scores: [0, 1, 2, ..., num_features-1].""" + return torch.arange(self.num_features, dtype=torch.float32) From 1d8bd20fe6162b0ac46d7c557ce61325fe36a882 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 12 Dec 2025 00:59:43 +0530 Subject: [PATCH 25/62] Remove runtime.py and directly use torch dist utils + remove unused functions (#667) ## What does this PR do? - Remove `runtime.py` and directly use `modelopt.torch.utils.distributed` functions - Read model_dtype and autocast_dtype from validate_model_defaults.yaml instead of runtime object - Remove more unused functions - Remove unnecessary parse_args for intermediate steps, and improve docstrings ## Testing - No change in functionality - Existing tests passing --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../validate_model_defaults.yaml | 2 + examples/compress/main.py | 149 ++--- .../activation_hooks/hooks.py | 39 +- .../score_pruning_activations.py | 50 +- .../_compress/build_library_and_stats.py | 19 - modelopt/torch/_compress/compress.py | 20 +- .../_compress/dataset/prepare_dataset.py | 3 +- .../_compress/mip/mip_and_realize_models.py | 53 +- modelopt/torch/_compress/mip/run_puzzle.py | 4 +- .../nas/plugins/compress_nas_plugin.py | 33 +- .../torch/_compress/pruning/pruning_ckpts.py | 12 - .../build_replacement_library.py | 14 - .../replacement_library.py | 29 +- modelopt/torch/_compress/scoring/scoring.py | 20 +- modelopt/torch/_compress/sewing_kit/common.py | 19 - .../_compress/sewing_kit/passage/core.py | 7 - modelopt/torch/_compress/sewing_kit/utils.py | 99 +--- .../calc_subblock_params_and_memory.py | 12 +- .../subblock_stats/calc_subblock_stats.py | 30 +- .../tools/bypassed_training/child_init.py | 2 - .../init_child_from_parent.py | 44 -- modelopt/torch/_compress/tools/hydra.py | 54 -- modelopt/torch/_compress/tools/runtime.py | 556 ------------------ .../tools/sharded_checkpoint_utils.py | 95 +-- .../torch/_compress/tools/validate_model.py | 174 +++--- ...validate_puzzle_with_multi_replacements.py | 172 +++--- .../torch/_compress/tools/validation_utils.py | 38 +- .../_compress/utils/checkpoint_manager.py | 23 +- .../torch/_compress/utils/data/dataloaders.py | 122 ---- modelopt/torch/_compress/utils/dist_utils.py | 30 - modelopt/torch/_compress/utils/utils.py | 2 +- .../utils/validate_runtime_pipeline.py | 189 +----- modelopt/torch/_compress/utils/validation.py | 269 --------- modelopt/torch/utils/distributed.py | 29 + setup.py | 1 - .../torch/_compress/compress_test_utils.py | 9 +- .../_compress/nas/plugins/test_nas_convert.py | 178 +++--- .../_compress/nas/plugins/test_nas_search.py | 126 ++-- .../configs/validate_model_defaults.yaml | 2 + tests/gpu/torch/_compress/test_compress.py | 100 ++-- 40 files changed, 619 insertions(+), 2210 deletions(-) delete mode 100644 modelopt/torch/_compress/sewing_kit/common.py delete mode 100644 modelopt/torch/_compress/tools/hydra.py delete mode 100644 modelopt/torch/_compress/tools/runtime.py delete mode 100644 modelopt/torch/_compress/utils/dist_utils.py diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml index 572331a84f..9e662c4e13 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -1,3 +1,5 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model block_size: 8192 bos_rate: 0.5 data_column: messages diff --git a/examples/compress/main.py b/examples/compress/main.py index c8b287fccd..2c3343c374 100644 --- a/examples/compress/main.py +++ b/examples/compress/main.py @@ -29,18 +29,18 @@ """ import argparse -import datetime +from datetime import timedelta from pathlib import Path -import mip_and_realize_models -import torch -from puzzle_tools.hydra_utils import register_hydra_resolvers - +import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.nas as mtn +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.hydra_utils import ( + initialize_hydra_config_for_dir, + register_hydra_resolvers, +) from modelopt.torch._compress.tools.logger import mprint -from tests.utils.test_utils import initialize_hydra_config_for_dir def parse_args(): @@ -70,50 +70,52 @@ def run_full_compress(hydra_config_path: str): config_path: Path to the YAML configuration file """ mprint("Compress Progress 1/8: starting compression pipeline") - with NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)): - # Register Hydra custom resolvers (needed for config resolution) - register_hydra_resolvers() - - hydra_config_path = Path(hydra_config_path).resolve() - hydra_config_dir = str(hydra_config_path.parent) - hydra_config_name = hydra_config_path.stem - - # Load hydra config - hydra_cfg = initialize_hydra_config_for_dir( - config_dir=hydra_config_dir, - config_name=hydra_config_name, - overrides=[], - ) - - # Convert model (convert from HF to DeciLM, score pruning activations, - # prune the model and save pruned checkpoints) - input_model = CompressModel() - converted_model = mtn.convert( - input_model, - mode=[ - ( - "compress", - { - "puzzle_dir": str(hydra_cfg.puzzle_dir), - "input_model_path": hydra_cfg.input_hf_model_path, - "hydra_config_dir": hydra_config_dir, - "hydra_config_name": hydra_config_name, - "dataset_path": str(hydra_cfg.dataset_path), - }, - ) - ], - ) - - # Run NAS search (build replacement library and compute stats, - # compute one block scores, run MIP and realize models) - mtn.search( - converted_model, - constraints={}, # this is not used as the search space is defined in the hydra config - dummy_input=None, # Not used - config={}, # this is not used as the search space is defined in the hydra config - ) - - mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") + dist.setup(timeout=timedelta(10)) + + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # Convert model (convert from HF to DeciLM, score pruning activations, + # prune the model and save pruned checkpoints) + input_model = CompressModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(hydra_cfg.puzzle_dir), + "input_model_path": hydra_cfg.input_hf_model_path, + "hydra_config_dir": hydra_config_dir, + "hydra_config_name": hydra_config_name, + "dataset_path": str(hydra_cfg.dataset_path), + }, + ) + ], + ) + + # Run NAS search (build replacement library and compute stats, + # compute one block scores, run MIP and realize models) + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + dist.cleanup() + mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") def run_mip_only(hydra_config_path: str): @@ -125,30 +127,29 @@ def run_mip_only(hydra_config_path: str): Args: hydra_config_path: Path to the YAML configuration file """ + dist.setup(timeout=timedelta(10)) + + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # mip_and_realize_models (distributed processing) + # TODO: How to make it part of mnt.search() api, similarly to run_full_compress() API + mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) - with NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) as runtime: - # Register Hydra custom resolvers (needed for config resolution) - register_hydra_resolvers() - - hydra_config_path = Path(hydra_config_path).resolve() - hydra_config_dir = str(hydra_config_path.parent) - hydra_config_name = hydra_config_path.stem - - # Load hydra config - hydra_cfg = initialize_hydra_config_for_dir( - config_dir=hydra_config_dir, - config_name=hydra_config_name, - overrides=[], - ) - - # mip_and_realize_models (distributed processing) - # TODO: How to make it part of mnt.search() api, similarly to run_full_compress() API - mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) - - mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") + dist.cleanup() + mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") def main(): diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py index 6339d55ab6..510f691111 100644 --- a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py @@ -18,7 +18,6 @@ activation scoring for pruning. """ -import argparse import gc import json from abc import ABC, abstractmethod @@ -30,6 +29,8 @@ from omegaconf import DictConfig, OmegaConf from torch import nn +import modelopt.torch.utils.distributed as dist + # BlockConfig used at runtime, not just type hints (lines 680, 790) from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import ( @@ -38,7 +39,6 @@ from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMRMSNorm from modelopt.torch._compress.tools.logger import aprint from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.tools.runtime import IRuntime def clear_gpu_memory(clear: bool) -> None: @@ -97,9 +97,8 @@ def dump_activations_logs( cls: type["ActivationsHook"], activation_hooks: dict[str, "ActivationsHook"], activations_log_dir: Path | str, - args: argparse.Namespace, - runtime: IRuntime | None, - ): + args: DictConfig, + ) -> None: """ Default implementation for dumping final activation scores logs to disk. This is called only at the end of scoring to save final results. @@ -107,7 +106,7 @@ def dump_activations_logs( activations_log_dir = Path(activations_log_dir) activations_log_dir.mkdir(exist_ok=True, parents=True) - rank = runtime.global_rank if runtime is not None else 0 + rank = dist.rank() activations_log_path = activations_log_dir / f"rank_{rank}.pth" activations_log = { module_name: hook.to_dict() for module_name, hook in activation_hooks.items() @@ -116,14 +115,8 @@ def dump_activations_logs( if rank == 0: args.activation_hooks_kwargs.pop("model") - json_dump( - OmegaConf.to_container(args, resolve=True) - if isinstance(args, DictConfig) - else vars(args), - activations_log_dir / "args.json", - ) - if runtime is not None: - runtime.wait_for_everyone() # rank 0 will not wait before dumping args.json + json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") + dist.barrier() aprint(f"Dumped final activations log to {activations_log_path}") @@ -132,8 +125,7 @@ def save_hook_states( cls: type["ActivationsHook"], activation_hooks: dict[str, "ActivationsHook"], activations_log_dir: Path | str, - runtime: IRuntime | None, - ): + ) -> None: """ Save hook states for checkpointing (separate from final results). This can be called periodically during scoring. @@ -141,7 +133,7 @@ def save_hook_states( """ activations_log_dir = Path(activations_log_dir) activations_log_dir.mkdir(exist_ok=True, parents=True) - rank = runtime.global_rank if runtime is not None else 0 + rank = dist.rank() hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth" hook_states = { @@ -461,29 +453,28 @@ def dump_activations_logs( cls: type["LayerNormContributionHook"], activation_hooks: dict[str, "ActivationsHook"], activations_log_dir: Path | str, - args: argparse.Namespace, - runtime: IRuntime | None, - ): + args: DictConfig, + ) -> None: """ At the end of the default implementation of dumping activation scores to disc, save aggregated channel importance results. """ - super().dump_activations_logs(activation_hooks, activations_log_dir, args, runtime) + super().dump_activations_logs(activation_hooks, activations_log_dir, args) - rank = runtime.global_rank if runtime is not None else 0 + rank = dist.rank() if rank == 0: LayerNormContributionHook._save_channel_importance_results( activation_hooks, activations_log_dir, args ) - runtime.wait_for_everyone() + dist.barrier() @staticmethod def _save_channel_importance_results( activation_hooks: dict[str, ActivationsHook], activations_log_dir: Path, - args: argparse.Namespace, + args: DictConfig, ) -> None: """ Save channel importance results from activation hooks. diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py index 4a276e8e82..f271a5f4f9 100644 --- a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -15,16 +15,12 @@ from pathlib import Path -import hydra import torch from omegaconf import DictConfig -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.validate_model import validate_model -from modelopt.torch._compress.utils.dist_utils import is_distributed -from modelopt.torch._compress.utils.parsing import format_global_config def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: @@ -50,23 +46,20 @@ def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: return method in supported_methods -def check_scoring_completion( - activations_log_dir: str, runtime, activation_hooks_kwargs=None -) -> bool: +def check_scoring_completion(activations_log_dir: str, activation_hooks_kwargs=None) -> bool: """ Check if scoring is already completed by looking for the expected output files. Also checks if the scoring method is safe for resume. Args: activations_log_dir: Directory where activation logs should be stored - runtime: Runtime object for distributed processing activation_hooks_kwargs: Hook configuration to check if resume is safe Returns: bool: True if scoring is completed (has rank files and args.json) """ - # Only check completion on main process (or if no distributed runtime) - if runtime is None or runtime.is_main_process: + # Only check completion on main process + if dist.is_master(): log_dir = Path(activations_log_dir) # Check if directory exists @@ -95,14 +88,13 @@ def check_scoring_completion( return False -def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: +def should_skip_scoring_completely(cfg: DictConfig) -> bool: """ Determine if we should skip scoring entirely (only if 100% complete). Partial progress should proceed to validate_model for proper resume. Args: cfg: Configuration object - runtime: Runtime object for distributed processing Returns: bool: True if we should skip scoring (100% completed), False if we should run/resume it @@ -123,11 +115,11 @@ def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: # Check if scoring is already completed is_completed = check_scoring_completion( - cfg.pruning.activations_log_dir, runtime, activation_hooks_kwargs + cfg.pruning.activations_log_dir, activation_hooks_kwargs ) # Broadcast the result to all processes in distributed mode - if runtime is not None and runtime.world_size > 1: + if dist.size() > 1: should_skip = [is_completed] # Use list for mutable object torch.distributed.broadcast_object_list(should_skip, src=0) is_completed = should_skip[0] @@ -141,34 +133,12 @@ def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: # Old progress tracking removed - checkpoint manager handles all progress tracking -def launch_score_activations(cfg: DictConfig, runtime): +def launch_score_activations(cfg: DictConfig): # Check if we should skip scoring entirely (only if 100% complete) - if should_skip_scoring_completely(cfg, runtime): + if should_skip_scoring_completely(cfg): return mprint("Starting pruning activation scoring...") # The checkpoint manager inside validate_model handles all progress tracking - validate_model(args=cfg.pruning, runtime=runtime) - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(format_global_config(cfg, title="Score Pruning Activations")) - - _runtime = ( - NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") - ) - if is_distributed() - else BaseRuntime(dtype=torch.bfloat16) - ) - with _runtime as runtime: - launch_score_activations(cfg, runtime) - runtime.wait_for_everyone() - - -if __name__ == "__main__": - register_hydra_resolvers() - main() + validate_model(args=cfg.pruning, pipeline_parallel=True) diff --git a/modelopt/torch/_compress/build_library_and_stats.py b/modelopt/torch/_compress/build_library_and_stats.py index f0735c98ff..28e0f386c2 100644 --- a/modelopt/torch/_compress/build_library_and_stats.py +++ b/modelopt/torch/_compress/build_library_and_stats.py @@ -88,22 +88,3 @@ def launch_build_library_and_stats(cfg: DictConfig) -> None: mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}") if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"): mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}") - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - """ - Main entry point for the unified build library and stats command. - - This function uses Hydra for configuration management and runs both - build_replacement_library and calc_subblock_stats in sequence. - """ - cfg = hydra.utils.instantiate(cfg) - mprint("Unified Build Library and Stats Configuration:") - mprint(format_global_config(cfg)) - launch_build_library_and_stats(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 8504631cbc..21e9df2af0 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -27,12 +27,12 @@ import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch._compress.scoring.scoring as scoring +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir -from modelopt.torch._compress.tools.runtime import IRuntime def compress( - hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str, runtime: IRuntime + hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str ) -> DictConfig: """Compress a puzzletron model using the MIP-based NAS search algorithm. @@ -41,8 +41,6 @@ def compress( hydra_config (str): the corresponding hydra config file puzzle_dir (str): directory with a puzzletron model to compress dataset_path (str): dataset used for scoring and distillation - runtime: distributed runtime to use to run the compression steps, e.g., - NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)) Returns: Hydra config object after compressing the model. @@ -60,22 +58,22 @@ def compress( ) # Step 1: score_pruning_activations (distributed processing) - score_pruning_activations.launch_score_activations(hydra_cfg, runtime) + score_pruning_activations.launch_score_activations(hydra_cfg) # Step 2: pruning_ckpts (single process) - if runtime.global_rank == 0: + if dist.is_master(): pruning_ckpts.launch_prune_ckpt(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() # Step 4: build_library_and_stats (single process) - if runtime.global_rank == 0: + if dist.is_master(): build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() # Step 5: calc_one_block_scores (distributed processing) - scoring.launch_scoring(hydra_cfg, runtime) + scoring.launch_scoring(hydra_cfg) # Step 6: mip_and_realize_models (distributed processing) - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) return hydra_cfg diff --git a/modelopt/torch/_compress/dataset/prepare_dataset.py b/modelopt/torch/_compress/dataset/prepare_dataset.py index 49d63d1227..072640777a 100644 --- a/modelopt/torch/_compress/dataset/prepare_dataset.py +++ b/modelopt/torch/_compress/dataset/prepare_dataset.py @@ -18,7 +18,8 @@ import datasets import fire import numpy as np -from logger import mprint + +from modelopt.torch._compress.tools.logger import mprint def process_and_save_dataset( diff --git a/modelopt/torch/_compress/mip/mip_and_realize_models.py b/modelopt/torch/_compress/mip/mip_and_realize_models.py index f6d77d2624..a3a1a84b91 100644 --- a/modelopt/torch/_compress/mip/mip_and_realize_models.py +++ b/modelopt/torch/_compress/mip/mip_and_realize_models.py @@ -19,19 +19,15 @@ from pathlib import Path from typing import List -import hydra import torch -import torch.distributed as dist from omegaconf import DictConfig +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.mip.run_puzzle import run_puzzle -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( validate_puzzle_solutions, ) -from modelopt.torch._compress.utils.dist_utils import is_distributed def launch_mip(cfg: DictConfig) -> List[str]: @@ -39,19 +35,18 @@ def launch_mip(cfg: DictConfig) -> List[str]: return solution_paths -def launch_realize_model(cfg: DictConfig, runtime: IRuntime): - validate_puzzle_solutions(args=cfg.realize_model, runtime=runtime) +def launch_realize_model(cfg: DictConfig): + validate_puzzle_solutions(args=cfg.realize_model) -def launch_mip_and_realize_model(cfg: DictConfig, runtime: IRuntime): +def launch_mip_and_realize_model(cfg: DictConfig): # Determine device for distributed operations (NCCL requires CUDA tensors) device = "cpu" - if runtime.world_size > 1 and dist.is_initialized(): - backend = dist.get_backend() - if backend == "nccl": + if dist.size() > 1: + if torch.distributed.get_backend() == "nccl": device = torch.cuda.current_device() - if runtime.is_main_process: + if dist.is_master(): solution_paths = launch_mip(cfg) length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long, device=device) else: @@ -59,39 +54,19 @@ def launch_mip_and_realize_model(cfg: DictConfig, runtime: IRuntime): length_tensor = torch.tensor([0], dtype=torch.long, device=device) if not cfg.skip_realize_model: - if runtime.world_size > 1: - dist.broadcast(length_tensor, src=0) + if dist.size() > 1: + torch.distributed.broadcast(length_tensor, src=0) list_length = length_tensor.item() - if runtime.global_rank != 0: + if not dist.is_master(): solution_paths = [None] * list_length - if runtime.world_size > 1: - dist.broadcast_object_list(solution_paths, src=0) + if dist.size() > 1: + torch.distributed.broadcast_object_list(solution_paths, src=0) for solution_path in solution_paths: mprint(f"Realize model for the solution: {solution_path}") cfg.realize_model.solutions_path = Path(solution_path) - launch_realize_model(cfg, runtime=runtime) - runtime.wait_for_everyone() - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - - _runtime = ( - NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") - ) - if is_distributed() - else BaseRuntime(dtype=torch.bfloat16) - ) - with _runtime as runtime: - launch_mip_and_realize_model(cfg, runtime) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() + launch_realize_model(cfg) + dist.barrier() diff --git a/modelopt/torch/_compress/mip/run_puzzle.py b/modelopt/torch/_compress/mip/run_puzzle.py index 5773349c11..4868479e23 100644 --- a/modelopt/torch/_compress/mip/run_puzzle.py +++ b/modelopt/torch/_compress/mip/run_puzzle.py @@ -226,7 +226,7 @@ def parse_args() -> argparse.Namespace: def run_single_puzzle_config( - args: argparse.Namespace, + args: argparse.Namespace | DictConfig, gathered_metrics: dict, subblock_stats: dict, subblock_stats_args: dict, @@ -426,7 +426,7 @@ def _get_minimal_unique_names(dicts: List[dict]) -> List[str]: return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts] -def run_puzzle(args: argparse.Namespace) -> List[str]: +def run_puzzle(args: argparse.Namespace | DictConfig) -> List[str]: # Loads config from args/puzzle_profile if args.puzzle_profile is not None: with open(args.puzzle_profile) as f: diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 5c08c693a2..55b9d10b0f 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -29,6 +29,7 @@ import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch._compress.scoring.scoring as scoring +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress import build_library_and_stats from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( @@ -36,7 +37,6 @@ ) from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.mode import ( @@ -99,13 +99,6 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR The output of this step will be used by mnt.search() to perform the NAS search. """ - - # NativeDdpRuntime must be initialized/closed from outside of this function, so we are - # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it. - runtime = NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) - # Required for mtn.search() to read NAS configuration model.hydra_config_dir = config.hydra_config_dir model.hydra_config_name = config.hydra_config_name @@ -124,26 +117,26 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR # Convert Llama3 model to DeciLM model # TODO: Make it generic, do not call convert_llama3_to_decilm directly. - if runtime.global_rank == 0: + if dist.is_master(): mprint("Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)") hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable convert_llama3_to_decilm( input_dir=config.input_model_path, output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, ) - runtime.wait_for_everyone() + dist.barrier() # Score_pruning_activations (distributed processing) mprint("Compress Progress 3/8: scoring pruning activations (multi-gpu)") - score_pruning_activations.launch_score_activations(hydra_cfg, runtime) + score_pruning_activations.launch_score_activations(hydra_cfg) # Prune the model and save pruned checkpoints - if runtime.global_rank == 0: + if dist.is_master(): mprint( "Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" ) pruning_ckpts.launch_prune_ckpt(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() return model, {} @@ -203,12 +196,6 @@ def default_state_dict(self) -> SearchStateDict: return {} def run_search(self) -> None: - # NativeDdpRuntime must be initialized/closed from outside of this function, so we are - # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it. - runtime = NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) - # Load hydra config hydra_cfg = initialize_hydra_config_for_dir( config_dir=self.model.hydra_config_dir, @@ -220,17 +207,17 @@ def run_search(self) -> None: ) # Build_library_and_stats (single process) - if runtime.global_rank == 0: + if dist.is_master(): mprint( "Compress Progress 5/8: building replacement library and subblock statistics (single-gpu)" ) build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() # Calc_one_block_scores (distributed processing) mprint("Compress Progress 6/8: calculating one block scores (multi-gpu)") - scoring.launch_scoring(hydra_cfg, runtime) + scoring.launch_scoring(hydra_cfg) # mip_and_realize_models (distributed processing) mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/_compress/pruning/pruning_ckpts.py b/modelopt/torch/_compress/pruning/pruning_ckpts.py index 4a0e5c15cd..b413a3f783 100644 --- a/modelopt/torch/_compress/pruning/pruning_ckpts.py +++ b/modelopt/torch/_compress/pruning/pruning_ckpts.py @@ -337,15 +337,3 @@ def launch_prune_ckpt(cfg: DictConfig): raise NotImplementedError( f"checkpoint pruning is not currently supported for target layer: {target_layer}" ) - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(cfg) - launch_prune_ckpt(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/replacement_library/build_replacement_library.py b/modelopt/torch/_compress/replacement_library/build_replacement_library.py index a8b2b7f9b6..760952a609 100644 --- a/modelopt/torch/_compress/replacement_library/build_replacement_library.py +++ b/modelopt/torch/_compress/replacement_library/build_replacement_library.py @@ -40,7 +40,6 @@ from pathlib import Path from typing import Any, Type -import hydra import pandas as pd from omegaconf import DictConfig @@ -59,7 +58,6 @@ is_valid_decilm_checkpoint, load_model_config, ) -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.robust_json import json_dump from modelopt.torch._compress.utils.parsing import format_global_config @@ -591,15 +589,3 @@ def _build_single_sequence_replacement_solutions( ) return solutions - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(format_global_config(cfg)) - launch_build_replacement_library(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/replacement_library/replacement_library.py b/modelopt/torch/_compress/replacement_library/replacement_library.py index ccfaaee0de..5e2fee6f0d 100644 --- a/modelopt/torch/_compress/replacement_library/replacement_library.py +++ b/modelopt/torch/_compress/replacement_library/replacement_library.py @@ -30,6 +30,7 @@ from safetensors.torch import load_file as safe_load_file from torch import nn +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, @@ -124,18 +125,11 @@ def create_model_config(self, layer_replacements: list[dict]): model_config = self.model_config.set_block_configs(block_configs) return model_config - def load_model( - self, - layer_replacements: list[dict], - world_size: int, - global_rank: int, - ) -> DeciLMForCausalLM: + def load_model(self, layer_replacements: list[dict]) -> DeciLMForCausalLM: block_configs, block_locations = extract_block_configs_and_locations(layer_replacements) model_config = self.model_config.set_block_configs(block_configs) - owned_block_indexes = _get_owned_block_indexes( - model_config.get_num_hidden_layers(), world_size, global_rank - ) + owned_block_indexes = _get_owned_block_indexes(model_config.get_num_hidden_layers()) model = create_dummy_model(model_config, self.dtype) is_first_shard = 0 in owned_block_indexes @@ -157,15 +151,10 @@ def load_model( self._move_inactive_blocks_to_cpu(active_blocks) return model - def load_checkpoint( - self, - checkpoint_dir: str | Path, - world_size: int, - global_rank: int, - ) -> DeciLMForCausalLM: + def load_checkpoint(self, checkpoint_dir: str | Path) -> DeciLMForCausalLM: checkpoint_dir = Path(checkpoint_dir).resolve() layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) - model = self.load_model(layer_replacements, world_size, global_rank) + model = self.load_model(layer_replacements) return model def _locate_replacements_of_entire_checkpoint(self, checkpoint_dir: str | Path) -> list[dict]: @@ -371,18 +360,18 @@ def _error_message_ensure_split(checkpoint_dir: Path) -> str: ) -def _get_owned_block_indexes(n_layer: int, world_size: int, global_rank: int) -> list[int]: +def _get_owned_block_indexes(n_layer: int) -> list[int]: last_process_blocks = np.array([n_layer - 1]) # less params in last gpu, leave room for logits - if world_size == 1: + if dist.size() == 1: # Only one process: assign everything (including the "last process" block) to rank 0 owned_block_indexes_per_process = [ np.concatenate([np.arange(n_layer - 1), last_process_blocks]) ] else: # Multiple processes: split n_layer-1 blocks, reserve the last for "last process" - owned_block_indexes_per_process = np.array_split(range(n_layer - 1), world_size - 1) + owned_block_indexes_per_process = np.array_split(range(n_layer - 1), dist.size() - 1) owned_block_indexes_per_process.append(last_process_blocks) - owned_block_indexes = owned_block_indexes_per_process[global_rank].tolist() + owned_block_indexes = owned_block_indexes_per_process[dist.rank()].tolist() return owned_block_indexes diff --git a/modelopt/torch/_compress/scoring/scoring.py b/modelopt/torch/_compress/scoring/scoring.py index f17b8cd3e3..5f745b3990 100644 --- a/modelopt/torch/_compress/scoring/scoring.py +++ b/modelopt/torch/_compress/scoring/scoring.py @@ -27,13 +27,12 @@ import torch from omegaconf import DictConfig +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( validate_puzzle_solutions, ) -from modelopt.torch._compress.utils.dist_utils import is_distributed def extract_solution_id(filename): @@ -73,26 +72,19 @@ def get_solutions_to_validate(cfg: DictConfig): return _solutions_to_validate -def launch_scoring(cfg: DictConfig, runtime: IRuntime): +def launch_scoring(cfg: DictConfig): cfg.scoring.solutions_to_validate = get_solutions_to_validate(cfg) mprint(f"Solutions to validate: {cfg.scoring.solutions_to_validate}") - validate_puzzle_solutions(args=cfg.scoring, runtime=runtime) + validate_puzzle_solutions(args=cfg.scoring) @hydra.main("", version_base="1.3") def main(cfg: DictConfig) -> None: cfg = hydra.utils.instantiate(cfg) mprint(cfg) - - _runtime = ( - NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") - ) - if is_distributed() - else BaseRuntime(dtype=torch.bfloat16) - ) - with _runtime as runtime: - launch_scoring(cfg, runtime) + dist.setup(timeout=cfg.nccl_timeout_minutes) + launch_scoring(cfg) + dist.cleanup() if __name__ == "__main__": diff --git a/modelopt/torch/_compress/sewing_kit/common.py b/modelopt/torch/_compress/sewing_kit/common.py deleted file mode 100644 index 5bc5732320..0000000000 --- a/modelopt/torch/_compress/sewing_kit/common.py +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -logger = logging.getLogger("sewing_kit") -logger.setLevel(logging.WARN) diff --git a/modelopt/torch/_compress/sewing_kit/passage/core.py b/modelopt/torch/_compress/sewing_kit/passage/core.py index 71164f061f..22c720b503 100644 --- a/modelopt/torch/_compress/sewing_kit/passage/core.py +++ b/modelopt/torch/_compress/sewing_kit/passage/core.py @@ -16,20 +16,13 @@ # mypy: ignore-errors from __future__ import annotations -import sys from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Any, ContextManager, Iterable, Mapping, Optional, Union -try: - from typing import Self -except ImportError: - from typing_extensions import Self - import torch.nn as nn from typing_extensions import override -from ..common import logger from ..utils import ( ActivityContext, dynamo_skip, diff --git a/modelopt/torch/_compress/sewing_kit/utils.py b/modelopt/torch/_compress/sewing_kit/utils.py index ff47c289b6..25ee8c9eab 100644 --- a/modelopt/torch/_compress/sewing_kit/utils.py +++ b/modelopt/torch/_compress/sewing_kit/utils.py @@ -16,7 +16,7 @@ from __future__ import annotations import inspect -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from contextlib import contextmanager from typing import ( Any, @@ -76,65 +76,6 @@ def __init__(self, module: TModule): self.module = module -Reduction = Literal["none", "mean", "sum"] - - -def normalized_mse_loss( - input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6 -): - loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( - target, torch.zeros_like(target) + epsilon, reduction=reduction - ) - return loss - - -def mse_loss(input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6): - loss = F.mse_loss(input, target, reduction=reduction) - return loss - - -class NormalizedMSELoss(nn.modules.loss._Loss): - __constants__ = ["reduction", "epsilon"] - - def __init__(self, reduction: Reduction = "mean", epsilon: float = 1e-6) -> None: - super().__init__(None, None, reduction) - self.epsilon = epsilon - - def forward(self, input: Tensor, target: Tensor) -> Tensor: - loss = normalized_mse_loss( - input, - target, - cast(Reduction, self.reduction), - self.epsilon, - ) - return loss - - -def vectorwise_normalized_mse_loss(input: Tensor, target: Tensor, epsilon: float = 1e-6): - """ - Like normalized_mse_loss, but the input is treated as a multi-dimensional batch of vectors. - Normalization is done on each vector separately (the last dim), then results are averaged. - """ - return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) - - -def batched_normalized_mse_loss( - input: Tensor, target: Tensor, epsilon: float = 1e-6, batch_dims: Sequence[int] = (0,) -): - """ - Like normalized_mse_loss, but the input is treated as a batch of tensors. - Normalization is done on the non-batch dims, then results are averaged. - """ - norm_dims = list(set(range(input.ndim)) - set(batch_dims)) - norm_of_target_vectors = F.mse_loss( - target, torch.zeros_like(target) + epsilon, reduction="none" - ).mean(dim=norm_dims) - vectorwise_mse = F.mse_loss(input, target, reduction="none").mean(dim=norm_dims) - normalized_vectorwise_mse = vectorwise_mse / norm_of_target_vectors - loss = normalized_vectorwise_mse.mean() - return loss - - class ActivityContextMaxDepthException(Exception): pass @@ -216,20 +157,6 @@ def is_submodule_or_same(module_name: str, other_module_name: str) -> bool: return result -def reduce_losses(losses: Iterable[Tensor]) -> Tensor: - total_loss = None - for loss in losses: - if total_loss is None: - total_loss = loss - else: - total_loss += loss - - if total_loss is None: - return torch.Tensor(torch.nan) - - return total_loss - - fake_mode = FakeTensorMode( allow_non_fake_inputs=True, # allow_fallback_kernels=False, @@ -423,30 +350,6 @@ def has_fake_tensor(v: Any) -> bool: return result -@dynamo_skip -def is_real_tensor(t: Any) -> bool: - return isinstance(t, Tensor) and not t.is_meta and not isinstance(t, FakeTensor) - - -@dynamo_skip -def get_parent_module_name(module_name: str): - if "." not in module_name: - return "" - else: - return module_name.rsplit(".", 1)[0] - - -@dynamo_skip -def get_parent_module_names(module_name: str): - parent_module_names = set[str]() - - while len(module_name) > 0: - module_name = get_parent_module_name(module_name) - parent_module_names.add(module_name) - - return parent_module_names - - def _get_device_for_distributed( group: Optional[torch.distributed.ProcessGroup] = None, ) -> str: diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py index 7f5a417786..e25c8e38d4 100644 --- a/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py @@ -50,7 +50,7 @@ def calculate_subblock_memory( prefill_queue_size: int, n_embd: int, n_head: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, kv_cache_dtype: torch.dtype, allocate_prefill_query: bool, ) -> float | dict[str, float]: @@ -174,7 +174,7 @@ def calculate_attention_memory( prefill_queue_size: int, n_embd: int, n_head: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, kv_cache_dtype: torch.dtype, allocate_prefill_query: bool, ) -> dict[str, float]: @@ -221,8 +221,8 @@ def calculate_mamba_memory( mamba_config: MambaConfig, n_embd: int, batch_size: int, - weights_dtype: torch.dtype | str, - kv_cache_dtype: torch.dtype | str, + weights_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, ) -> int: return ( calculate_mamba_params(mamba_config, n_embd) * sizeof_dtype(weights_dtype) @@ -274,7 +274,7 @@ def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...] def calculate_linear_memory( n_embd: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, ) -> float: return calculate_linear_params(n_embd) * sizeof_dtype(weights_dtype) / 2**20 @@ -288,7 +288,7 @@ def calculate_linear_params( def calculate_ffn_memory( ffn_config: FFNConfig, n_embd: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, ) -> float: num_params = calculate_ffn_params(ffn_config, n_embd) return num_params * sizeof_dtype(weights_dtype) / 2**20 diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py b/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py index d3e73a0cf8..76e6c34281 100644 --- a/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py @@ -16,20 +16,16 @@ """Calc subblock stats to compute memory and runtime statistics for subblocks.""" -import os -from itertools import product - -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - import dataclasses import json +import os from functools import partial +from itertools import product from pathlib import Path from typing import Iterable, Optional, Type, TypeVar -import hydra +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + import pandas as pd import torch from immutabledict import immutabledict @@ -42,6 +38,7 @@ FFNConfig, SubblockConfig, ) +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement from modelopt.torch._compress.subblock_stats.calc_subblock_params_and_memory import ( calc_subblock_active_params, @@ -51,7 +48,6 @@ calculate_subblock_params, ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.robust_json import json_dump from modelopt.torch._compress.utils.parsing import format_global_config @@ -91,9 +87,7 @@ def calculate_subblock_stats( ) -> dict: is_calc_runtime = benchmark_iterations is not None if is_calc_runtime: - from puzzle_tools.subblock_stats.runtime_stats.calc_runtime_stats import ( - calc_runtime_ms_for_subblocks, - ) + raise NotImplementedError("Runtime stats calculation is not implemented yet") gpu = None if not torch.cuda.is_available() else torch.cuda.get_device_name() subblock_stats = { @@ -540,15 +534,3 @@ def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> di if len(matching_bf16_stats) == 1: return matching_bf16_stats[0] raise ValueError(f"Found more than 1 matching bf16 stats for {args=}") - - -@hydra.main("configs", version_base="1.3", config_name="search_space") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(format_global_config(cfg)) - launch_calc_subblock_stats(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/tools/bypassed_training/child_init.py b/modelopt/torch/_compress/tools/bypassed_training/child_init.py index d9ead79a1c..3e2c42f09c 100644 --- a/modelopt/torch/_compress/tools/bypassed_training/child_init.py +++ b/modelopt/torch/_compress/tools/bypassed_training/child_init.py @@ -39,7 +39,6 @@ ) from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.tools.logger import aprint, mprint -from modelopt.torch._compress.tools.runtime import IRuntime class GQAInitMode(Enum): @@ -331,7 +330,6 @@ def create_child_state_dict( new_config: DeciLMConfig, gqa_init_mode: GQAInitMode, ignore_fn: IgnoreFn = default_ignore_fn, - runtime: Optional[IRuntime] = Printer, mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, mlp_init_config: Optional[dict[str, Any]] = None, owned_block_indexes: Optional[set[int]] = None, diff --git a/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py index dbb4eac0c8..f06db92fbe 100644 --- a/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py @@ -220,47 +220,3 @@ def init_child_from_parent( mprint(f"Total core processing: {total_core_time:.2f}s") mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") mprint(f"=========================\n") - - -def parse_args(): - parser = argparse.ArgumentParser() - - # Arguments for single checkpoint creation - parser.add_argument("--parent_checkpoint_dir", type=str, required=True) - parser.add_argument("--model_config_overrides_json", type=str, required=True) - parser.add_argument("--output_checkpoint_dir", type=str, required=True) - parser.add_argument( - "--gqa_init_mode", type=str, default="AverageKV", choices=GQAInitMode._member_names_ - ) - parser.add_argument( - "--mlp_init_mode", type=str, default="Truncate", choices=MlpInitMode._member_names_ - ) - parser.add_argument("--mlp_init_config_yaml", type=str, default=None) - parser.add_argument( - "--linear_init_mode", type=str, default="FromTeacher", choices=LinearInitMode._member_names_ - ) - parser.add_argument( - "--hidden_size_init_mode", type=str, default=None, choices=HiddenSizeInitMode._member_names_ - ) - parser.add_argument("--channel_importance_path", type=str, required=False) - parser.add_argument("--target_hidden_sizes", type=int, nargs="+", required=False) - - args = parser.parse_args() - return args - - -if __name__ == "__main__": - args = parse_args() - - init_child_from_parent( - parent_checkpoint_dir=args.parent_checkpoint_dir, - model_config_overrides_json=args.model_config_overrides_json, - output_checkpoint_dir=args.output_checkpoint_dir, - gqa_init_mode=GQAInitMode(args.gqa_init_mode), - mlp_init_mode=MlpInitMode(args.mlp_init_mode), - mlp_init_config_yaml=args.mlp_init_config_yaml, - linear_init_mode=LinearInitMode(args.linear_init_mode), - hidden_size_init_mode=HiddenSizeInitMode(args.hidden_size_init_mode) - if args.hidden_size_init_mode - else None, - ) diff --git a/modelopt/torch/_compress/tools/hydra.py b/modelopt/torch/_compress/tools/hydra.py deleted file mode 100644 index 8c36d309e4..0000000000 --- a/modelopt/torch/_compress/tools/hydra.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from hydra import compose, initialize, initialize_config_dir -from omegaconf import DictConfig, OmegaConf - -""" -Utilities for hydra config initialization. -""" - - -def initialize_hydra_config_for_dir( - config_dir: str, config_name: str, overrides: list[str] -) -> DictConfig: - """Initialize a hydra config from an absolute path for a config directory - - Args: - config_dir (str): - config_name (str): - overrides (List[str]): - - Returns: - DictConfig: - """ - - with initialize_config_dir(version_base=None, config_dir=config_dir): - args = compose(config_name, overrides) - args._set_flag("allow_objects", True) - OmegaConf.resolve(args) # resolve object attributes - OmegaConf.set_struct(args, False) - - return args - - -def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig: - with initialize(version_base=None, config_path=config_path): - args = compose(config_name, overrides) - args._set_flag("allow_objects", True) - OmegaConf.resolve(args) # resolve object attributes - OmegaConf.set_struct(args, False) - - return args diff --git a/modelopt/torch/_compress/tools/runtime.py b/modelopt/torch/_compress/tools/runtime.py deleted file mode 100644 index 46f561a5d9..0000000000 --- a/modelopt/torch/_compress/tools/runtime.py +++ /dev/null @@ -1,556 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Classes for torch distributed runtime management""" - -import os -import random -from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Iterator, Sequence -from contextlib import AbstractContextManager, suppress -from datetime import timedelta -from pathlib import Path -from typing import Literal, TypeVar, cast - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn -from torch.utils.data import DataLoader -from tqdm import tqdm -from typing_extensions import override - -PrepareModelsT = TypeVar("PrepareModelsT", bound=Sequence[nn.Module]) -PrepareDataLoaderT = TypeVar("PrepareDataLoaderT", bound=DataLoader) -CompileT = TypeVar("CompileT", bound=nn.Module) -Filter = ( - Literal["main_process", "last", "local_main_process", "local_last", "all"] - | list[int] - | set[int] - | Callable[[int], bool] -) - - -class IRuntime(ABC): - @abstractmethod - def setup(self) -> None: ... - - @abstractmethod - def cleanup(self) -> None: ... - - @abstractmethod - def autocast(self) -> AbstractContextManager: ... - - @abstractmethod - def wait_for_everyone(self) -> None: ... - - @abstractmethod - def set_seed(self, seed: int, device_specific: bool = False) -> int: ... - - @abstractmethod - def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: ... - - @abstractmethod - def prepare_train_dataloader( - self, train_dataloader: PrepareDataLoaderT - ) -> PrepareDataLoaderT: ... - - @abstractmethod - def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: ... - - @abstractmethod - def compile(self, model: CompileT) -> CompileT: ... - - @abstractmethod - def backward(self, loss: torch.Tensor) -> None: ... - - @abstractmethod - def clip_grad_norm_( - self, - parameters: Iterable[torch.Tensor] | torch.Tensor, - max_norm: float, - norm_type: float = 2, - ) -> torch.Tensor: ... - - @abstractmethod - def clip_grad_value_( - self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float - ) -> None: ... - - @abstractmethod - def save_state(self, path: str | Path) -> None: ... - - @abstractmethod - def load_state(self, path: str | Path) -> None: ... - - @abstractmethod - def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: ... - - @property - @abstractmethod - def sync_gradients(self) -> bool: ... - - @property - @abstractmethod - def device(self) -> torch.device: ... - - @property - @abstractmethod - def is_main_process(self) -> bool: ... - - @property - @abstractmethod - def is_local_main_process(self) -> bool: ... - - @property - @abstractmethod - def is_last_process(self) -> bool: ... - - @property - @abstractmethod - def is_local_last_process(self) -> bool: ... - - @property - @abstractmethod - def local_rank(self) -> int: ... - - @property - @abstractmethod - def global_rank(self) -> int: ... - - @property - @abstractmethod - def local_world_size(self) -> int: ... - - @property - @abstractmethod - def world_size(self) -> int: ... - - @property - @abstractmethod - def dtype(self) -> torch.dtype: ... - - def __enter__(self): - self.setup() - return self - - def __exit__(self, exc_type, exc_value, traceback): - # avoid barrier if exceution errored - if exc_type is None: - self.cleanup() - - # if exc_type is not None: - # raise exc_value - # Handle exceptions if necessary - # pass - - # def __del__(self): - # torch.distributed.barrier() - # torch.distributed.destroy_process_group() - - def check_filter(self, filter_: Filter): - return ( - filter_ == "all" - or (filter_ == "main_process" and self.is_main_process) - or (filter_ == "local_main_process" and self.is_local_main_process) - or (filter_ == "last" and self.is_last_process) - or (filter_ == "local_last" and self.is_local_last_process) - or (isinstance(filter_, (list, set)) and self.global_rank in filter_) - or (callable(filter_) and filter_(self.global_rank)) - ) - - def print( - self, *args, filter_: Filter = "main_process", rank_prefix=False, flush=True, **kwargs - ) -> None: - if not self.check_filter(filter_): - return - - if rank_prefix: - print(f"[global_rank={self.global_rank}]", *args, flush=flush, **kwargs) - else: - print(*args, flush=flush, **kwargs) - - def process_print( - self, *args, filter_: Filter = "all", rank_prefix=True, flush=True, **kwargs - ) -> None: - if not self.check_filter(filter_): - return - - if rank_prefix: - prefix = f"[global_rank={self.global_rank}]" - if len(args) == 1: # avoid out-of-order printing if possible - out = f"{prefix} {args[0]}" - args = (out,) - else: - args = (prefix, *args) - print(*args, flush=flush, **kwargs) - else: - print(*args, flush=flush, **kwargs) - - -class NativeDdpRuntime(IRuntime): - def __init__( - self, - dtype: torch.dtype = torch.float, - torch_distributed_timeout: timedelta | None = None, - ): - self._master_addr = os.environ["MASTER_ADDR"] - self._master_port = int(os.environ["MASTER_PORT"]) - self._local_rank = int(os.environ["LOCAL_RANK"]) - self._global_rank = int(os.environ["RANK"]) - self._local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - self._world_size = int(os.environ["WORLD_SIZE"]) - self._device = torch.device(self.local_rank) - self._dtype = dtype - self._torch_distributed_timeout = torch_distributed_timeout - - @override - def setup(self): - torch.cuda.set_device(self._device) - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - "cpu:gloo,cuda:nccl", timeout=self._torch_distributed_timeout - ) - input_tensors = [ - torch.tensor([0], dtype=torch.float32, device=self._device) - for _ in range(self.world_size) - ] - output_tensors = [ - torch.tensor([0], dtype=torch.float32, device=self._device) - for _ in range(self.world_size) - ] - torch.distributed.all_to_all(input_tensors, output_tensors) - - @override - def cleanup(self): - with suppress(Exception): - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - @override - def autocast(self) -> AbstractContextManager: - result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) - return result - - @override - def wait_for_everyone(self): - torch.distributed.barrier() - - @override - def set_seed(self, seed: int, device_specific: bool = False) -> int: - """ - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - - Args: - seed (`int`): - The seed to set. - device_specific (`bool`, *optional*, defaults to `False`): - Whether to differ the seed on each device slightly with `self.process_index`. - """ - if device_specific: - seed += self.global_rank - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - return seed - - @override - def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: - assert all(isinstance(x, nn.Module) for x in models) - new_models = [nn.parallel.DistributedDataParallel(m) for m in models] - new_models = cast("PrepareModelsT", new_models) - return new_models # type: ignore[return-value] - - @override - def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return train_dataloader - - @override - def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return val_dataloader - - @override - def compile(self, model: CompileT) -> CompileT: - result = torch.compile(model) - result = cast("CompileT", result) - return result - - @override - def backward(self, loss: torch.Tensor) -> None: - loss.backward() - - @override - def clip_grad_norm_( - self, - parameters: Iterable[torch.Tensor] | torch.Tensor, - max_norm: float, - norm_type: float = 2, - ) -> torch.Tensor: - result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) - return result - - @override - def clip_grad_value_( - self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float - ) -> None: - torch.nn.utils.clip_grad_value_(parameters, clip_value) - - @override - def save_state(self, path: str | Path) -> None: - pass - - @override - def load_state(self, path: str | Path) -> None: - pass - - @override - def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: - for _ in tqdm( - range(num_batches), desc=f"rank {self._global_rank}: skip_first_batches({num_batches=})" - ): - next(dataloader_iterator) - - @property - @override - def sync_gradients(self) -> bool: - return True - - @property - @override - def is_main_process(self) -> bool: - result = self.global_rank == 0 - return result - - @property - @override - def is_local_main_process(self) -> bool: - result = self.local_rank == 0 - return result - - @property - @override - def is_last_process(self) -> bool: - result = self.global_rank == self.world_size - 1 - return result - - @property - @override - def is_local_last_process(self) -> bool: - result = self.local_rank == self.local_world_size - 1 - return result - - @property - @override - def local_rank(self) -> int: - return self._local_rank - - @property - @override - def global_rank(self) -> int: - return self._global_rank - - @property - @override - def local_world_size(self) -> int: - return self._local_world_size - - @property - @override - def world_size(self) -> int: - return self._world_size - - @property - @override - def device(self) -> torch.device: - return self._device - - @property - @override - def dtype(self) -> torch.dtype: - return self._dtype - - @property - def master_addr(self) -> str: - return self._master_addr - - @property - def master_port(self) -> int: - return self._master_port - - -class BaseRuntime(IRuntime): - def __init__(self, dtype: torch.dtype = torch.float): - self._device = torch.device(self.local_rank) - self._dtype = dtype - - @override - def setup(self): - torch.cuda.set_device(self._device) - - @override - def cleanup(self): ... - - @override - def autocast(self) -> AbstractContextManager: - result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) - return result - - @override - def wait_for_everyone(self): ... - - @override - def set_seed(self, seed: int, device_specific: bool = False) -> int: - """ - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - - Args: - seed (`int`): - The seed to set. - device_specific (`bool`, *optional*, defaults to `False`): - Whether to differ the seed on each device slightly with `self.process_index`. - """ - if device_specific: - seed += self.global_rank - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - return seed - - @override - def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: - assert all(isinstance(x, nn.Module) for x in models) - return models - - @override - def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return train_dataloader - - @override - def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return val_dataloader - - @override - def compile(self, model: CompileT) -> CompileT: - result = torch.compile(model) - result = cast("CompileT", result) - return result - - @override - def backward(self, loss: torch.Tensor) -> None: - loss.backward() - - @override - def clip_grad_norm_( - self, - parameters: Iterable[torch.Tensor] | torch.Tensor, - max_norm: float, - norm_type: float = 2, - ) -> torch.Tensor: - result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) - return result - - @override - def clip_grad_value_( - self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float - ) -> None: - torch.nn.utils.clip_grad_value_(parameters, clip_value) - - @override - def save_state(self, path: str | Path) -> None: - pass - - @override - def load_state(self, path: str | Path) -> None: - pass - - @override - def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: - for _ in tqdm( - range(num_batches), desc=f"rank {self.global_rank}: skip_first_batches({num_batches=})" - ): - next(dataloader_iterator) - - @property - @override - def sync_gradients(self) -> bool: - return True - - @property - @override - def is_main_process(self) -> bool: - result = self.global_rank == 0 - return result - - @property - @override - def is_local_main_process(self) -> bool: - result = self.local_rank == 0 - return result - - @property - @override - def is_last_process(self) -> bool: - result = self.global_rank == self.world_size - 1 - return result - - @property - @override - def is_local_last_process(self) -> bool: - result = self.local_rank == self.local_world_size - 1 - return result - - @property - @override - def local_rank(self) -> int: - return 0 - - @property - @override - def global_rank(self) -> int: - return 0 - - @property - @override - def local_world_size(self) -> int: - return 1 - - @property - @override - def world_size(self) -> int: - return 1 - - @property - @override - def device(self) -> torch.device: - return self._device - - @property - @override - def dtype(self) -> torch.dtype: - return self._dtype - - @property - def master_addr(self) -> str | None: - return None - - @property - def master_port(self) -> int | None: - return None diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 8d1a222c89..7a247bbdf0 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -37,6 +37,7 @@ from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, @@ -45,7 +46,6 @@ ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.utils.utils import EmptyInitOnDevice @@ -144,14 +144,14 @@ def create_dummy_model( def load_and_shard_model( - runtime: IRuntime, checkpoint_path: str | Path, owned_block_indexes: set[int] | Literal["auto"] = "auto", model_config: DeciLMConfig | None = None, model_config_overrides: Mapping | None = None, + model_dtype: torch.dtype = torch.bfloat16, ) -> DeciLMForCausalLM: checkpoint_path = Path(checkpoint_path) - with runtime.device: + with torch.device(dist.local_rank()): if model_config is None: model_config = load_model_config( checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True @@ -159,14 +159,13 @@ def load_and_shard_model( if owned_block_indexes == "auto": owned_block_indexes = set( - np.array_split(np.arange(model_config.get_num_hidden_layers()), runtime.world_size)[ - runtime.global_rank + np.array_split(np.arange(model_config.get_num_hidden_layers()), dist.size())[ + dist.rank() ] ) mprint("Initializing model shards") model_shard = create_sharded_model( - runtime=runtime, model_config=model_config, owned_block_indexes=owned_block_indexes, ) @@ -182,7 +181,7 @@ def load_and_shard_model( shard_state_dict = load_sharded_state_dict( model_name_or_path=str(checkpoint_path), keys_to_load=shard_keys, - device=runtime.device, + device=torch.device(dist.local_rank()), ) new_names = set(shard_state_dict.keys()) @@ -196,15 +195,13 @@ def load_and_shard_model( model_shard.tie_weights() else: mprint("Loading state_dict in main process") - state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None + state_dict = load_state_dict(checkpoint_path) if dist.is_master() else None mprint("Distributing model to shards") - load_state_dict_to_shards( - runtime=runtime, model_shard=model_shard, loaded_state_dict=state_dict - ) + load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict) del state_dict - model_shard.type(runtime.dtype) + model_shard.type(model_dtype) params_on_meta_device = [ param_name @@ -212,14 +209,13 @@ def load_and_shard_model( if param.device == torch.device("meta") ] assert len(params_on_meta_device) == 0, ( - f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" + f"[global_rank={dist.rank()}] Couldn't load params {params_on_meta_device}" ) return model_shard def create_sharded_model( - runtime: IRuntime, model_config: DeciLMConfig, owned_block_indexes: set[int], device: str | torch.device | None = "meta", @@ -228,7 +224,7 @@ def create_sharded_model( if isinstance(device, str): device = torch.device(device) - runtime.wait_for_everyone() + dist.barrier() with EmptyInitOnDevice(device="meta", dtype=dtype): model = DeciLMForCausalLM(model_config) @@ -245,15 +241,18 @@ def create_sharded_model( def load_state_dict_to_shards( - runtime: IRuntime, model_shard: torch.nn.Module, loaded_state_dict: dict | None = None + model_shard: torch.nn.Module, loaded_state_dict: dict | None = None ) -> None: - from sewing_kit.utils import distributed_isend_obj, distributed_recv_obj + from modelopt.torch._compress.sewing_kit.utils import ( + distributed_isend_obj, + distributed_recv_obj, + ) model_shard.to("meta") local_state_dict_keys = list(model_shard.state_dict().keys()) - if runtime.is_main_process: - gathered_state_dict_keys = [None] * runtime.world_size + if dist.is_master(): + gathered_state_dict_keys = [None] * dist.size() torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys) assert loaded_state_dict is not None @@ -276,7 +275,7 @@ def load_state_dict_to_shards( torch.distributed.gather_object(local_state_dict_keys) shard_state_dict = distributed_recv_obj() - print(f"{runtime.global_rank=} loaded state_dict shard") + print(f"{dist.rank()} loaded state_dict shard") missing_keys, unexpected_keys = model_shard.load_state_dict( shard_state_dict, strict=False, assign=True @@ -284,20 +283,18 @@ def load_state_dict_to_shards( assert len(unexpected_keys) == 0 assert all("dummy_param" in key for key in missing_keys) - model_shard.to(runtime.device) + model_shard.cuda(dist.local_rank()) - runtime.wait_for_everyone() + dist.barrier() def save_sharded_model( - runtime: IRuntime, - model_shard: torch.nn.Module | dict[str, torch.Tensor], - out_path: str | Path, + model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path ): """ out_path is usually output_checkpoint_path / "model.safetensors" """ - runtime.wait_for_everyone() + dist.barrier() if isinstance(model_shard, torch.nn.Module): shard_state_dict = model_shard.state_dict() @@ -311,8 +308,8 @@ def save_sharded_model( weight.numel() * weight.element_size() for weight in shard_state_dict.values() ) - num_shards = runtime.world_size - idx = runtime.global_rank + num_shards = dist.size() + idx = dist.rank() out_path = Path(out_path) shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}") @@ -323,8 +320,8 @@ def save_sharded_model( "shard_file": str(shard_file), } - if runtime.is_main_process: - shard_metadatas = [{} for _ in range(runtime.world_size)] + if dist.is_master(): + shard_metadatas = [{} for _ in range(dist.size())] torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0) total_size = sum(x["total_shard_size"] for x in shard_metadatas) metadata = {"total_size": total_size} @@ -346,33 +343,7 @@ def save_sharded_model( else: torch.save(shard_state_dict, shard_file) - runtime.wait_for_everyone() - - -def save_sharded_state_dict( - state_dict: dict[str, torch.Tensor], - save_directory: str | Path, - max_shard_size: str = "10GB", -) -> None: - save_directory = Path(save_directory) - save_directory.mkdir(exist_ok=True, parents=True) - state_dict = {k: v.cpu() for k, v in state_dict.items()} - - state_dict_split = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size) - - for shard_filename, param_names in tqdm( - state_dict_split.filename_to_tensors.items(), desc="saving sharded state dict" - ): - shard_path = save_directory / shard_filename - shard = {param_name: state_dict[param_name] for param_name in param_names} - safe_save_file(shard, shard_path, metadata={"format": "pt"}) - - index = { - "metadata": state_dict_split.metadata, - "weight_map": state_dict_split.tensor_to_filename, - } - index_path = save_directory / SAFE_WEIGHTS_INDEX_NAME - index_path.write_text(json.dumps(index, indent=2)) + dist.barrier() def load_sharded_state_dict( @@ -410,13 +381,3 @@ def _resolve_shard_paths(model_name_or_path: str) -> list[str]: def is_in_safetensors_format(checkpoint_dir: Path) -> bool: return len(list(checkpoint_dir.glob("*.safetensors"))) > 0 - - -def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: - shard_paths = _resolve_shard_paths(model_name_or_path) - state_dict_shapes = {} - for safetensors_path in shard_paths: - with safe_open(safetensors_path, framework="pt") as f: - for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable - state_dict_shapes[key] = tuple(f.get_tensor(key).shape) - return state_dict_shapes diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index 8ec1d6f172..d3d71a4198 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -21,11 +21,10 @@ TODO: Consider moving this a separate module dedicated for scoring. """ -import argparse import textwrap from pathlib import Path -import torch.distributed +import torch from omegaconf import DictConfig from torch import nn from torch.utils.data import DataLoader @@ -36,12 +35,12 @@ PreTrainedTokenizerBase, ) +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( register_activation_hooks, ) from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint from modelopt.torch._compress.tools.logger import aprint, mprint -from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader from modelopt.torch._compress.utils.parsing import simple_parse_args_string @@ -51,12 +50,6 @@ ) from modelopt.torch._compress.utils.validation import calculate_losses -# #TODO:Import slack from root utils directory -# root_path = os.path.join(os.path.dirname(__file__), "..", "..") -# if root_path not in sys.path: -# sys.path.append(root_path) -# from utils.slack import send_slack_message - """ Two goals: 1) Calculate lm loss and token accuracy for a model. @@ -67,88 +60,89 @@ 2) Register hooks to capture the inputs and the outputs of pytorch modules. For example, to collect activations scores for various layers (ffn, layer_norm, etc.) that are used for pruning (ffn_hidden_size, embedding_pruning, etc). -See --activations_log_dir and --activation_hooks_kwargs args arguments. - +See activations_log_dir and activation_hooks_kwargs arguments. """ -def build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_name_or_path", - type=str, - default=None, - help="Required unless a model is passed to the function", - ) - parser.add_argument("--dataset_path", type=str, required=True) - - parser.add_argument("--output_dir_name", type=str, default="validation") - parser.add_argument( - "--calculate_full_score_ablations", - action="store_true", - help="Calculates a diverse suite of teacher similarity scores. " - "By default only a small suite is calculated, which is good for most use-cases.", - ) - - parser.add_argument("--tokenizer_name", type=str, default=None) - parser.add_argument("--data_column", type=str, default="content") - # TODO: Add help text for FIM rate, also for others less obvious args - parser.add_argument("--fim_rate", type=float, default=0) - parser.add_argument("--fim_spm_rate", type=float, default=0) - parser.add_argument("--eval_samples", type=int, default=None) - parser.add_argument("--block_size", type=int, default=4096) - parser.add_argument("--micro_batch_size", type=int, default=4) - parser.add_argument("--val_dataset_name", type=str, default="__auto__") - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--source_datasets_to_discard", nargs="+", type=str) - parser.add_argument("--bos_rate", type=float, default=1.0) - parser.add_argument("--shuffle_seed", type=int, default=None) - parser.add_argument("--varlen", action="store_true") - parser.add_argument("--pipeline_parallel", action="store_true") - parser.add_argument("--write_results", action="store_true") - parser.add_argument("--activations_log_dir", type=str, default=None) - parser.add_argument( - "--activation_hooks_kwargs", - type=str, - default=None, - help="Comma separated string arguments, e.g. `arg1=val1,arg2=val2`", - ) - parser.add_argument( - "--calc_losses_on_cpu", - action="store_true", - help="Very slow, not recommended. Can help avoid OOM.", - ) - return parser - - -def parse_args() -> argparse.Namespace: - parser = build_arg_parser() - args, unknown_args = parser.parse_known_args() - return args - - @torch.no_grad() def validate_model( - args: argparse.Namespace | DictConfig, + args: DictConfig, model: PreTrainedModel | None = None, tokenizer: PreTrainedTokenizerBase | None = None, target_hidden_states_per_batch: list[torch.Tensor] | None = None, return_hidden_states: bool = False, - runtime: IRuntime | None = None, + pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader: DataLoader | None = None, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + """Validate a language model on a dataset by calculating loss and optionally capturing activations. + + Args: + args: Configuration object containing the following attributes: + + Model Configuration: + - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. + Required unless model is passed directly. + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration: + - dataset_path (str): Path to the validation dataset. + - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. Uses all if None. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing: + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Activation Hooks: + - activations_log_dir (str, optional): Directory to log activation scores. If provided, + hooks will be registered to capture activations. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + If string, comma-separated format: "arg1=val1,arg2=val2". + + Execution Options: + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. + - write_results (bool): Write validation results to file. + + model: Pre-loaded model. If None, will be loaded from args.model_name_or_path. + tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args. + target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation. + return_hidden_states: Whether to return hidden states from the model. + pipeline_parallel: Enable pipeline parallelism for large models. + calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. + False calculates only a small suite for efficiency. + val_dataloader: Pre-created validation dataloader. If None, will be created from args. + + Returns: + A tuple containing: + - losses: Dictionary mapping loss names to loss statistics (avg, per_sample). + - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. + Returns (None, None) if not on master rank. + """ + # convert model_dtype and autocast_dtype from string to torch.dtype + if isinstance(args.model_dtype, str): + args.model_dtype = getattr(torch, args.model_dtype.strip("torch.")) + if isinstance(args.autocast_dtype, str): + args.autocast_dtype = getattr(torch, args.autocast_dtype.strip("torch.")) + if val_dataloader is None: - val_dataloader = ( - prepare_dataloader(args, tokenizer) - if (runtime is None or runtime.is_main_process) - else None - ) + val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None validation_full_iters = ( args.eval_samples // args.micro_batch_size ) # model pipeline, single data rank - model = prepare_model(args, model, runtime) + model = prepare_model(args, model, pipeline_parallel) just_model_forward = False checkpoint_manager = None @@ -175,7 +169,6 @@ def validate_model( ) checkpoint_manager = ScoringCheckpointManager( checkpoint_dir=args.activations_log_dir, - runtime=runtime, activation_hooks=activation_hooks, checkpoint_interval=50, # Save every 50 batches ) @@ -190,7 +183,7 @@ def validate_model( just_model_forward = True model.lm_head = nn.Identity() - if runtime is None: + if not pipeline_parallel: losses, hidden_states_per_batch = calculate_losses( model=model, dataloader=val_dataloader, @@ -198,7 +191,6 @@ def validate_model( ) else: losses, hidden_states_per_batch = calculate_losses_pipeline( - runtime=runtime, stitched_model=model, dataloader=val_dataloader, target_hidden_states_per_batch=target_hidden_states_per_batch, @@ -207,6 +199,7 @@ def validate_model( calc_on_cpu=args.calc_losses_on_cpu, just_model_forward=just_model_forward, checkpoint_manager=checkpoint_manager, + autocast_dtype=args.autocast_dtype, ) if losses is not None: @@ -223,26 +216,23 @@ def validate_model( aprint(results_str) if args.write_results: Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str) - # TODO: send_slack_message(results_str) if activation_hooks is not None: - hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args, runtime) + hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args) return losses, hidden_states_per_batch def prepare_model( - args: argparse.Namespace, - model: PreTrainedModel | None = None, - runtime: IRuntime | None = None, + args: DictConfig, model: PreTrainedModel | None = None, pipeline_parallel: bool = False ) -> nn.Module: if model is None: assert args.model_name_or_path is not None - if runtime is not None: + if pipeline_parallel: model = load_and_shard_model( - runtime, args.model_name_or_path, model_config_overrides={"block_size": args.block_size}, + model_dtype=args.model_dtype, ) else: try: @@ -265,8 +255,7 @@ def prepare_model( def prepare_dataloader( - args: argparse.Namespace, - tokenizer: PreTrainedTokenizerBase | None = None, + args: DictConfig, tokenizer: PreTrainedTokenizerBase | None = None ) -> DataLoader: if tokenizer is None: tokenizer_name = getattr(args, "tokenizer_name", None) @@ -295,16 +284,3 @@ def prepare_dataloader( ) return val_dataloader - - -def main(): - args = parse_args() - if args.pipeline_parallel: - with NativeDdpRuntime(dtype=torch.bfloat16) as runtime: - validate_model(args=args, runtime=runtime) - else: - validate_model(args=args, runtime=None) - - -if __name__ == "__main__": - main() diff --git a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py index e947e97e4e..ca02998684 100644 --- a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py @@ -20,7 +20,6 @@ # mypy: ignore-errors -import argparse import json import shutil import warnings @@ -29,11 +28,12 @@ from typing import Optional import torch +from omegaconf import DictConfig from tqdm import tqdm from transformers import AutoTokenizer, PreTrainedTokenizerBase +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.replacement_library.build_replacement_library import infer_teacher_dir from modelopt.torch._compress.replacement_library.replacement_library import ReplacementLibrary from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement from modelopt.torch._compress.tools import validate_model @@ -45,7 +45,6 @@ save_checkpoint, save_safetensors_index, ) -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.tools.validation_utils import ( validate_model_and_extract_hidden_states, validate_model_with_teacher_similarity_metrics, @@ -54,64 +53,71 @@ from modelopt.torch._compress.utils.validate_runtime_pipeline import perform_pipeline_stitches """ -Usage: -====== - -Validate single_block_replacement_solutions -=========================================== - -( -export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"; -PUZZLE_DIR=".../Llama-3_2-1B-Instruct/parallel_puzzle"; - -torchrun --nproc-per-node=8 \ - -m modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements \ - --replacement_library_path ${PUZZLE_DIR}/replacement_library.json \ - --solutions_path ${PUZZLE_DIR}/single_sequence_replacement_solutions.json \ - --solutions_to_validate 0 \ - \ - --dataset_path .../v0.4/valid \ - --data_column conversation --block_size 8192 --seed 42 --shuffle_seed 444 --bos_rate 0.5 \ - --eval_samples 32 --micro_batch_size 1 \ - \ - --save_models \ - -) +Usage Example: +============== +Validate single_block_replacement_solutions by calling validate_puzzle_solutions() directly +with an args object containing the required attributes. See the function docstring for details. """ -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument("--replacement_library_path", type=parse_path, required=True) - parser.add_argument("--solutions_path", type=parse_path, required=True) - parser.add_argument("--teacher_dir", type=parse_path, default=None) - parser.add_argument("--solutions_to_validate", type=int, nargs="+", default=None) - parser.add_argument("--sort_solutions_by", type=str, default=None) - parser.add_argument("--bigger_is_better", action="store_true") - parser.add_argument("--skip_validation", action="store_true") - parser.add_argument("--save_models", action="store_true") - args, unknown_args = parser.parse_known_args() - if not args.skip_validation: - validation_args = validate_model.build_arg_parser().parse_args(unknown_args) - args = argparse.Namespace( - **{**validation_args.__dict__, **args.__dict__} - ) # if arg names overlap, the latter one wins - else: - args.block_size = None - - args.teacher_dir = _try_infer_teacher_dir(args.replacement_library_path, args.teacher_dir) - - args.tokenizer_name = getattr(args, "tokenizer_name", None) - if args.tokenizer_name is None: - args.tokenizer_name = args.teacher_dir - - return args - - @torch.no_grad() -def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> None: +def validate_puzzle_solutions(args: DictConfig) -> None: + """Validate puzzle solutions by applying layer replacements and evaluating model performance. + + Args: + args: Configuration object containing the following attributes: + + Puzzle Configuration (Required): + - replacement_library_path (Path): Path to the replacement library JSON file. + - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. + Validates all solutions if None. + - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. + - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - skip_validation (bool): If True, skip model validation and only save models if requested. + - save_models (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration: + - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False): + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False): + - dataset_path (str): Path to the validation dataset. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False): + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Output Configuration: + - output_dir (Path, optional): Directory to save validation results. + Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False): + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. + - write_results (bool): Write validation results to file. + - activations_log_dir (str, optional): Directory to log activation scores. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + + Returns: + None. Saves validation results and optionally model checkpoints to disk. + """ puzzle_solutions = load_puzzle_solutions( args.solutions_path, args.sort_solutions_by, args.bigger_is_better ) @@ -122,9 +128,7 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No tokenizer = _load_tokenizer(args) if not args.skip_validation: val_dataloader = ( - validate_model.prepare_dataloader(args, tokenizer) - if (runtime is None or runtime.is_main_process) - else None + validate_model.prepare_dataloader(args, tokenizer) if dist.is_master() else None ) output_dir = ( @@ -137,18 +141,16 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No teacher_hidden_states = None if (args.teacher_dir is not None) and (not args.skip_validation): - teacher_model = replacement_library.load_checkpoint( - args.teacher_dir, runtime.world_size, runtime.global_rank - ) - teacher_model.to(runtime.device) - stitched_model = perform_pipeline_stitches(teacher_model, runtime) + teacher_model = replacement_library.load_checkpoint(args.teacher_dir) + teacher_model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(teacher_model) teacher_hidden_states = validate_model_and_extract_hidden_states( args, stitched_model, tokenizer, output_dir, model_name="teacher", - runtime=runtime, + pipeline_parallel=True, val_dataloader=val_dataloader, ) @@ -160,9 +162,7 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No # realizable_as_symlinks = False model_config = replacement_library.create_model_config(layer_replacements) if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): - model = replacement_library.load_model( - layer_replacements, runtime.world_size, runtime.global_rank - ) + model = replacement_library.load_model(layer_replacements) model_config = model.config if args.save_models: @@ -171,10 +171,10 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No / f"solution_{i_solution}" ) - model_config.dtype = "bfloat16" + model_config.dtype = args.model_dtype model_config.architectures = ["DeciLMForCausalLM"] if realizable_as_symlinks: - if runtime.global_rank == 0: + if dist.is_master(): save_checkpoint_as_symlinks( layer_replacements, model_config, checkpoint_dir, replacement_library ) @@ -184,13 +184,11 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No copy_tokenizer(args.tokenizer_name, checkpoint_dir) copy_hf_code(checkpoint_dir) - runtime.wait_for_everyone() - - runtime.wait_for_everyone() + dist.barrier() if not args.skip_validation: - model.to(runtime.device) - stitched_model = perform_pipeline_stitches(model, runtime) + model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(model) validate_model_with_teacher_similarity_metrics( args, stitched_model, @@ -199,11 +197,11 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No output_dir, model_name=f"solution_{i_solution}", extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, - runtime=runtime, + pipeline_parallel=True, val_dataloader=val_dataloader, ) - runtime.wait_for_everyone() + dist.barrier() def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool: @@ -255,23 +253,7 @@ def copy_hf_code(checkpoint_dir: Path) -> None: shutil.copy(file, checkpoint_dir / file.name) -def _try_infer_teacher_dir( - replacement_library_path: str | Path, - teacher_dir: str | Path | None, -) -> Path | None: - if teacher_dir is not None: - return teacher_dir - - try: - teacher_dir = infer_teacher_dir( - master_puzzle_dir=Path(replacement_library_path).parent, teacher_checkpoint_dir=None - ) - return teacher_dir - except: - return None - - -def _load_tokenizer(args: argparse.Namespace) -> PreTrainedTokenizerBase: +def _load_tokenizer(args: DictConfig) -> PreTrainedTokenizerBase: tokenizer = None if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) @@ -324,7 +306,3 @@ def load_puzzle_solutions( print(f"sorted solutions by {sort_solutions_by}. {vals[:10]=} {vals[-10:]=}") return puzzle_solutions - - -if __name__ == "__main__": - validate_puzzle_solutions(args=parse_args()) diff --git a/modelopt/torch/_compress/tools/validation_utils.py b/modelopt/torch/_compress/tools/validation_utils.py index 907dee4029..6f0b1fcb5d 100644 --- a/modelopt/torch/_compress/tools/validation_utils.py +++ b/modelopt/torch/_compress/tools/validation_utils.py @@ -20,31 +20,32 @@ # mypy: ignore-errors -import argparse from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from omegaconf import DictConfig, OmegaConf from torch import nn from transformers import PreTrainedTokenizerBase -from modelopt.torch._compress.sewing_kit import StitchedModule +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools import validate_model from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.utils.validation import LowMemorySparseTensor +if TYPE_CHECKING: + from modelopt.torch._compress.sewing_kit import StitchedModule + def validate_model_and_extract_hidden_states( - args: argparse.Namespace, - model: nn.Module | StitchedModule, + args: DictConfig, + model: "nn.Module | StitchedModule", tokenizer: PreTrainedTokenizerBase, - output_dir: Union[str, Path], + output_dir: str | Path, model_name: str, extra_payload: Optional[dict[str, Any]] = None, - runtime: Optional[IRuntime] = None, + pipeline_parallel: bool = False, val_dataloader=None, ) -> list[torch.Tensor | LowMemorySparseTensor]: mprint(f""" @@ -59,10 +60,10 @@ def validate_model_and_extract_hidden_states( model, tokenizer, return_hidden_states=True, - runtime=runtime, + pipeline_parallel=pipeline_parallel, val_dataloader=val_dataloader, ) - if runtime is None or runtime.is_last_process: + if dist.is_last_process(): output_dir = output_dir if (output_dir is not None) else args.bypass_dir extra_payload = extra_payload if (extra_payload is not None) else dict() write_results(output_dir, model_name, args, {**losses, **extra_payload}) @@ -70,14 +71,14 @@ def validate_model_and_extract_hidden_states( def validate_model_with_teacher_similarity_metrics( - args: argparse.Namespace, - model: nn.Module | StitchedModule, + args: DictConfig, + model: "nn.Module | StitchedModule", tokenizer: PreTrainedTokenizerBase, target_hidden_states_per_batch: list[torch.Tensor], - output_dir: Union[str, Path], + output_dir: str | Path, model_name: str, extra_payload: Optional[dict[str, Any]] = None, - runtime: Optional[IRuntime] = None, + pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader=None, ) -> None: @@ -94,20 +95,17 @@ def validate_model_with_teacher_similarity_metrics( model, tokenizer, target_hidden_states_per_batch=target_hidden_states_per_batch, - runtime=runtime, + pipeline_parallel=pipeline_parallel, calculate_full_score_ablations=calculate_full_score_ablations, val_dataloader=val_dataloader, ) - if runtime is None or runtime.is_last_process: + if dist.is_last_process(): extra_payload = extra_payload if (extra_payload is not None) else dict() write_results(output_dir, model_name, args, {**losses, **extra_payload}) def write_results( - output_dir: Union[str, Path], - result_name: str, - args: argparse.Namespace, - payload: dict[str, Any], + output_dir: str | Path, result_name: str, args: DictConfig, payload: dict[str, Any] ) -> None: output_path = Path(output_dir) / f"{result_name}.json" output_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/_compress/utils/checkpoint_manager.py index b96fd21a56..7a27334469 100644 --- a/modelopt/torch/_compress/utils/checkpoint_manager.py +++ b/modelopt/torch/_compress/utils/checkpoint_manager.py @@ -22,30 +22,27 @@ from pathlib import Path from typing import Any, Dict, Optional +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.logger import aprint, mprint class ScoringCheckpointManager: """Manages checkpointing for activation hook scoring with periodic saves.""" - def __init__( - self, checkpoint_dir: str, runtime, activation_hooks=None, checkpoint_interval: int = 100 - ): + def __init__(self, checkpoint_dir: str, activation_hooks=None, checkpoint_interval: int = 100): """ Initialize checkpoint manager. Args: checkpoint_dir: Directory to save checkpoints - runtime: Runtime object for distributed processing activation_hooks: Dictionary of activation hooks to manage checkpoint_interval: Save checkpoint every N batches """ self.checkpoint_dir = Path(checkpoint_dir) - self.runtime = runtime self.activation_hooks = activation_hooks self.checkpoint_interval = checkpoint_interval - self.rank = runtime.global_rank if runtime is not None else 0 - self.is_main_process = runtime is None or runtime.is_main_process + self.rank = dist.rank() + self.is_main_process = dist.is_master() # Debug: Log checkpoint manager initialization hook_count = len(activation_hooks) if activation_hooks else 0 @@ -200,9 +197,7 @@ def update_progress(self, batch_idx: int, total_batches: int): ActivationsHook, ) - saved_path = ActivationsHook.save_hook_states( - self.activation_hooks, self.checkpoint_dir, self.runtime - ) + ActivationsHook.save_hook_states(self.activation_hooks, self.checkpoint_dir) except Exception as e: mprint(f"Warning: Failed to save hook states: {e}") @@ -211,8 +206,7 @@ def update_progress(self, batch_idx: int, total_batches: int): self.save_checkpoint() # Synchronize all ranks after checkpointing - if self.runtime is not None: - self.runtime.wait_for_everyone() + dist.barrier() def save_checkpoint(self): """ @@ -260,7 +254,7 @@ def finalize(self): ) saved_path = ActivationsHook.save_hook_states( - self.activation_hooks, self.checkpoint_dir, self.runtime + self.activation_hooks, self.checkpoint_dir ) mprint(f"Final hook states saved to {saved_path}") except Exception as e: @@ -273,5 +267,4 @@ def finalize(self): mprint(f"Scoring completed and finalized: {self.total_batches} batches processed") # Synchronize all ranks after finalization - if self.runtime is not None: - self.runtime.wait_for_everyone() + dist.barrier() diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/_compress/utils/data/dataloaders.py index 584e32480b..865ad89fbc 100644 --- a/modelopt/torch/_compress/utils/data/dataloaders.py +++ b/modelopt/torch/_compress/utils/data/dataloaders.py @@ -17,7 +17,6 @@ DataLoader utilities for language model training and validation. """ -import os from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import Protocol, TypeVar @@ -74,58 +73,6 @@ def load_streaming_fn( return dataset -def create_train_dataloader( - accelerator: Accelerator, - seed: int, - tokenizer: PreTrainedTokenizerBase, - block_size: int, - dataset: str | Mapping[str, Dataset], - content_field: str, - fim_rate: float, - fim_spm_rate: float, - micro_batch_size: int, - load_dataset_fn: LoadDatasetFn = load_from_disk_fn, - dataset_name="train", - keep_in_memory: bool = False, - shuffle_train_data_seed: int | None = None, - source_datasets_to_discard: Sequence[str] = (), - bos_rate: float = 1.0, - varlen: bool = True, -): - mprint(f"\ncreate_train_dataloader on rank {accelerator.process_index}") - if isinstance(dataset, str): - dataset = load_dataset_fn(dataset, content_field, keep_in_memory) - - train_data = dataset[dataset_name] - if shuffle_train_data_seed is not None: - train_data = train_data.shuffle(seed=shuffle_train_data_seed) - - train_dataset = ConstantLengthDataset( - tokenizer, - train_data, - infinite=True, - seq_length=block_size * micro_batch_size if varlen else block_size, - content_field=content_field, - fim_rate=fim_rate, - fim_spm_rate=fim_spm_rate, - seed=seed, - source_datasets_to_discard=source_datasets_to_discard, - bos_rate=bos_rate, - # return_cu_seqlens=varlen, - # seqlen_cap=block_size if varlen else None - ) - - train_dataloader = DataLoader( - train_dataset, - batch_size=1 if varlen else micro_batch_size, - pin_memory=True, - collate_fn=collate_fn_with_none_support, - num_workers=os.cpu_count() // 2 // 8, - ) - - return train_dataloader - - def create_validation_dataloader( accelerator: Accelerator | None, seed: int, @@ -231,75 +178,6 @@ def realize_dataset_in_memory(dataset: IterableDataset, eval_samples: int | None return offloaded_dataset -def create_dataloaders( - accelerator: Accelerator, - seed: int, - tokenizer: PreTrainedTokenizerBase, - block_size: int, - dataset_path: str, - content_field: str, - fim_rate: float, - fim_spm_rate: float, - micro_batch_size: int, - val_micro_batch_size: int | None = None, - eval_samples: int | None = None, - load_dataset_fn: LoadDatasetFn = load_from_disk_fn, - train_dataset_name: str = "train", - val_dataset_name: str = "__auto__", - disable_validation: bool = False, - keep_in_memory: bool = False, - shuffle_train_data_seed: int | None = None, - source_datasets_to_discard: Sequence[str] = (), - bos_rate: float = 1.0, - varlen: bool = True, -): - if val_micro_batch_size is None: - val_micro_batch_size = micro_batch_size - - dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory=keep_in_memory) - - train_dataloader = create_train_dataloader( - accelerator, - seed, - tokenizer, - block_size, - dataset, - content_field, - fim_rate, - fim_spm_rate, - micro_batch_size, - load_dataset_fn, - train_dataset_name, - shuffle_train_data_seed=shuffle_train_data_seed, - source_datasets_to_discard=source_datasets_to_discard, - bos_rate=bos_rate, - varlen=varlen, - ) - - if not disable_validation: - val_dataloader = create_validation_dataloader( - accelerator, - seed, - tokenizer, - block_size, - dataset, - content_field, - fim_rate, - fim_spm_rate, - val_micro_batch_size, - eval_samples, - load_dataset_fn, - val_dataset_name, - source_datasets_to_discard=source_datasets_to_discard, - bos_rate=bos_rate, - varlen=varlen, - ) - else: - val_dataloader = None - - return train_dataloader, val_dataloader - - TensorT = TypeVar("TensorT", bound=torch.Tensor) diff --git a/modelopt/torch/_compress/utils/dist_utils.py b/modelopt/torch/_compress/utils/dist_utils.py deleted file mode 100644 index 84f8f2bab1..0000000000 --- a/modelopt/torch/_compress/utils/dist_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch.distributed as dist - - -def is_distributed(): - """ - From torchtune.utils.is_distributed() : https://docs.pytorch.org/torchtune/0.2/generated/torchtune.utils.is_distributed.html - """ - port = os.environ.get("MASTER_PORT", "") - addr = os.environ.get("MASTER_ADDR", "") - size = int(os.environ.get("WORLD_SIZE", 1)) - rank = int(os.environ.get("RANK", -1)) - avlb = dist.is_available() - return bool(port and addr and size > 1 and rank >= 0 and avlb) diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index d03ea80403..62b7678ebc 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -63,7 +63,7 @@ def raise_unknown_subblock_config_error(subblock_config: Any) -> None: ) -def sizeof_dtype(dtype: torch.dtype | str) -> int | float: +def sizeof_dtype(dtype: torch.dtype) -> int | float: """Return the size in bytes of the given data type. TODO: Consider a better place for this function. diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py index aa8a4f304b..b3be70644b 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -23,16 +23,12 @@ """ # mypy: ignore-errors -from statistics import mean - import numpy as np import torch -import torch.distributed -import wandb from torch.utils.data import DataLoader from tqdm import tqdm -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMForCausalLM, LMHead, @@ -52,148 +48,10 @@ fake_tensor, ) from modelopt.torch._compress.tools.checkpoint_utils import init_module_with_state_dict -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.tools.sharded_checkpoint_utils import DummyBlock from modelopt.torch._compress.utils.validation import _organize_outputs, calculate_batch_outputs -@torch.no_grad() -def validate_pipeline_inner( - runtime: IRuntime, - stitched_model: StitchedModule, - val_dataloader: DataLoader | None, -) -> float: - if runtime.is_main_process: - assert val_dataloader.batch_size is not None - model_device = next(stitched_model.parameters()).device - - with runtime.autocast(): - stitched_model.eval() - - all_logits: list[torch.Tensor] = [] - all_targets: list[torch.Tensor] = [] - losses: list[float] = [] - - if runtime.is_main_process: - input_ids: torch.Tensor - targets: torch.Tensor - - for i_batch, batch in enumerate(tqdm(val_dataloader)): - input_ids, targets = ( - batch["input_ids"].to(model_device), - batch["targets"].to(model_device), - ) - - if i_batch == 0: - num_batches = len(val_dataloader) - seq_len = input_ids.shape[1] - if torch.distributed.is_initialized(): - torch.distributed.broadcast_object_list([(num_batches, seq_len)]) - - all_targets.append(targets.cpu()) - - output = stitched_model({}, {}, input_ids) - logits = output.captured_outputs.get("model_output") - logits = getattr(logits, "logits", logits) - - if logits is not None: - all_logits.append(logits.cpu()) - - del output, logits - - if len(all_targets) > 0: - distributed_send_obj(all_targets, dst=runtime.world_size - 1) - - else: - obj_list: list[tuple] = [None] - torch.distributed.broadcast_object_list(obj_list) - num_batches, seq_len = obj_list[0] - - fake_input_ids = fake_tensor(1, seq_len, dtype=runtime.dtype) - - for i in range(num_batches): - output = stitched_model({}, {}, fake_input_ids) - logits = output.captured_outputs.get("model_output") - logits = getattr(logits, "logits", logits) - if logits is not None: - all_logits.append(logits.cpu()) - del output, logits - - if len(all_targets) == 0 and runtime.global_rank == runtime.world_size - 1: - all_targets = distributed_recv_obj(src=0) - - torch.distributed.barrier() - - if len(all_logits) > 0: - for logits, targets in zip(all_logits, all_targets): - logits = logits.to("cuda") - targets = targets.to("cuda") - logit_losses = torch.nn.functional.cross_entropy( - logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" - ) - - mean_losses = logit_losses.cpu().mean(dim=-1) - losses.extend(mean_losses.tolist()) - - val_loss = mean(losses) - - if not runtime.is_main_process: - distributed_send_obj(val_loss, dst=0) - elif runtime.is_main_process: - val_loss = distributed_recv_obj() - else: - val_loss = float("nan") - - stitched_model.train() - - loss_list = [val_loss] - torch.distributed.broadcast_object_list(loss_list) - val_loss = loss_list[0] - - return val_loss - - -@torch.no_grad() -def validate_pipeline( - runtime: IRuntime, - stitched_model: StitchedModule, - model_config: DeciLMConfig, - val_dataloader: DataLoader, - iter_num: int | None = None, - max_iters: int | None = None, - model_name: str | None = None, - enable_print: bool = True, - enable_wandb_log: bool = False, - # pad_to_batchsize: bool = True, -) -> float: - if enable_print: - mprint("Validating ...") - - val_loss = validate_pipeline_inner( - runtime=runtime, - stitched_model=stitched_model, - val_dataloader=val_dataloader, - ) - - if runtime.is_main_process: - key = "val/loss" if model_name is None else f"val/{model_name}_loss" - if enable_print: - prefix = "" - if iter_num is not None: - prefix += f"iter {iter_num}" - if max_iters is not None: - prefix += f"/{max_iters}" - prefix += " - " - mprint(f"{prefix}{key}: {val_loss:.4f}") - if enable_wandb_log: - wandb.log({key: val_loss}, step=iter_num) - - runtime.wait_for_everyone() - - return val_loss - - class HiddenStatesAndLMHead(list): def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): super().__init__(hidden_states) @@ -202,7 +60,6 @@ def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Ten @torch.no_grad() def calculate_losses_pipeline( - runtime: IRuntime, stitched_model: StitchedModule | DeciLMForCausalLM, dataloader: DataLoader | None, target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, @@ -211,6 +68,7 @@ def calculate_losses_pipeline( calc_on_cpu: bool = False, just_model_forward: bool = False, checkpoint_manager=None, + autocast_dtype: torch.dtype = torch.bfloat16, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: """ Do model forward on each batch and calculate LM loss. @@ -232,27 +90,27 @@ def calculate_losses_pipeline( """ if isinstance(stitched_model, DeciLMForCausalLM): - stitched_model = perform_pipeline_stitches(stitched_model, runtime) + stitched_model = perform_pipeline_stitches(stitched_model) params = list(stitched_model.parameters()) model_device = params[0].device if params else "cpu" # Pre-populate outputs with dummy values for skipped batches start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 - if runtime.is_last_process: + if dist.is_last_process(): outputs = [{"lm_loss": [0.0]}] * start_batch else: outputs = None - if runtime.is_main_process: + if dist.is_master(): all_input_ids, all_targets = zip( *[(batch["input_ids"], batch["targets"]) for batch in dataloader] ) - if runtime.world_size > 1: - distributed_send_obj(all_targets, dst=runtime.world_size - 1) + if dist.size() > 1: + distributed_send_obj(all_targets, dst=dist.size() - 1) - if runtime.is_last_process: - if runtime.world_size > 1: + if dist.is_last_process(): + if dist.size() > 1: all_targets = distributed_recv_obj(src=0) lm_head: LMHead = next( @@ -268,37 +126,37 @@ def calculate_losses_pipeline( {"weight": lm_head_weights}, LMHead, *lm_head_weights.shape[::-1], bias=False ) - if runtime.is_main_process: + if dist.is_master(): num_batches = len(all_input_ids) seq_len = all_input_ids[0].shape[1] - if runtime.world_size > 1: + if dist.size() > 1: torch.distributed.broadcast_object_list([num_batches, seq_len]) # Create progress bar with sliced range starting from checkpoint position desc = ( - f"[rank {runtime.global_rank}] calculate_losses_pipeline(" + f"[rank {dist.rank()}] calculate_losses_pipeline(" f"{(target_hidden_states_per_batch is None)=}, {return_hidden_states=}, {num_batches=})" ) progress_bar = tqdm(range(start_batch, num_batches), desc=desc) else: obj_list = [None, None] - if runtime.world_size > 1: + if dist.size() > 1: torch.distributed.broadcast_object_list(obj_list) num_batches, seq_len = obj_list progress_bar = range(start_batch, num_batches) stitched_model.eval() - with runtime.autocast(): + with torch.autocast(device_type="cuda", dtype=autocast_dtype): for i_batch in progress_bar: - if runtime.is_main_process: + if dist.is_master(): input_ids = all_input_ids[i_batch].to(model_device) else: input_ids = fake_tensor(1, seq_len, dtype=torch.long) output = stitched_model({}, {}, input_ids) - if runtime.is_last_process: + if dist.is_last_process(): logits = output.captured_outputs.get("model_output") logits = getattr(logits, "logits", logits) hidden_states = output.captured_outputs.get("hidden_states") @@ -340,14 +198,11 @@ def calculate_losses_pipeline( hidden_states_per_batch, lm_head.weight.cpu() ) - runtime.wait_for_everyone() + dist.barrier() return losses, hidden_states_per_batch -def perform_pipeline_stitches( - model: DeciLMForCausalLM, - runtime: IRuntime, -) -> StitchedModule: +def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: target = ModuleTarget("module", model) stitcher = Needle() @@ -356,10 +211,10 @@ def perform_pipeline_stitches( ) first_block, last_block = is_real_block.min(), is_real_block.max() - if runtime.global_rank != 0: + if dist.rank() != 0: # receive activations from previous rank stitcher.stitch( - RemoteTarget(peer_rank=runtime.global_rank - 1).value( + RemoteTarget(peer_rank=dist.rank() - 1).value( name="activations", adapter=lambda x: InputArgs(x) ), target.input( @@ -370,11 +225,11 @@ def perform_pipeline_stitches( ), ) - if not runtime.is_last_process: + if not dist.is_last_process(): # send activations to next rank stitcher.stitch( target.output(f"model.layers.{last_block}"), - RemoteTarget(peer_rank=runtime.global_rank + 1).value(name="activations"), + RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"), ) else: # register model output diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py index 662ae4a2b6..d970105e68 100644 --- a/modelopt/torch/_compress/utils/validation.py +++ b/modelopt/torch/_compress/utils/validation.py @@ -24,14 +24,10 @@ import functools import math from enum import Enum -from statistics import mean import numpy as np import torch -import torch.distributed import torch.nn.functional as F -import wandb -from accelerate import Accelerator from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm @@ -39,159 +35,6 @@ from typing_extensions import Self from modelopt.torch._compress.tools import kd_model -from modelopt.torch._compress.utils.data.dataloaders import create_padded_tensor - - -@torch.no_grad() -def _validate_single( - accelerator: Accelerator, - model: torch.nn.Module, - rope_cache: torch.Tensor | None, - val_dataloader: DataLoader, - pad_to_batchsize: bool = True, - compute_kl_div: bool = False, - varlen: bool = False, - concat_token_id: int | None = None, -) -> list[float]: - assert val_dataloader.batch_sampler.batch_size is not None - desired_batch_size = val_dataloader.batch_sampler.batch_size - - with accelerator.device, accelerator.autocast(): - model.eval() - - losses: list[float] = [] - - input_ids: torch.LongTensor - targets: torch.LongTensor - is_first_batch = True - for batch in tqdm(val_dataloader, disable=not accelerator.is_main_process): - if is_first_batch: - print( - f"First batch, device {accelerator.device}, input_ids: {batch['input_ids'][:4]}" - ) - is_first_batch = False - input_ids, targets = ( - batch["input_ids"].to(accelerator.device), - batch["targets"].to(accelerator.device), - ) - batch_size = input_ids.size(0) - - if pad_to_batchsize: - input_ids = create_padded_tensor( - input_ids, (desired_batch_size, *input_ids.shape[1:]) - ) - targets = create_padded_tensor(targets, (desired_batch_size, *targets.shape[1:])) - - if rope_cache is not None: - logits = model( - input_ids, rope_cache=rope_cache, varlen=varlen, concat_token_id=concat_token_id - ) - else: - logits = model(input_ids) - - if hasattr(logits, "logits"): # For HF models - logits = logits.logits - - if isinstance(logits, tuple): # For KD - logits, teacher_logits, kd_block_loss, kd_logits_loss = logits - - if compute_kl_div: - # assumes kd_logits_loss has entry for each batch item - batch_losses = kd_logits_loss[:batch_size] - else: - batch_losses = torch.nn.functional.cross_entropy( - logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" - )[:batch_size].mean(dim=-1) - - losses.extend(batch_losses.tolist()) - - model.train() - - return losses - - -@torch.no_grad() -def validate_parallel( - accelerator: Accelerator, - model: torch.nn.Module, - rope_cache: torch.Tensor | None, - val_dataloader: DataLoader, - pad_to_batchsize: bool = True, - compute_kl_div: bool = False, - varlen: bool = False, - concat_token_id: int | None = None, -) -> float: - losses = _validate_single( - accelerator=accelerator, - model=model, - rope_cache=rope_cache, - val_dataloader=val_dataloader, - pad_to_batchsize=pad_to_batchsize, - compute_kl_div=compute_kl_div, - varlen=varlen, - concat_token_id=concat_token_id, - ) - - results = [float("nan")] - if accelerator.is_main_process: - gathered_results = [[float("nan")]] * accelerator.num_processes - torch.distributed.gather_object(losses, gathered_results) - gathered_losses = [l for result in gathered_results for l in result] - results[0] = mean(gathered_losses) - else: - torch.distributed.gather_object(losses) - - torch.distributed.broadcast_object_list(results) - val_loss = results[0] - - return val_loss - - -@torch.no_grad() -def validate( - accelerator: Accelerator, - model: torch.nn.Module, - rope_cache: torch.Tensor | None, - val_dataloader: DataLoader, - iter_num: int | None = None, - max_iters: int | None = None, - model_name: str | None = None, - enable_print: bool = True, - enable_wandb_log: bool = False, - pad_to_batchsize: bool = True, - compute_kl_div: bool = False, - varlen: bool = False, - concat_token_id: int | None = None, -) -> float: - if enable_print: - accelerator.print("Validating ...") - - val_loss = validate_parallel( - accelerator=accelerator, - model=model, - rope_cache=rope_cache, - val_dataloader=val_dataloader, - pad_to_batchsize=pad_to_batchsize, - compute_kl_div=compute_kl_div, - varlen=varlen, - concat_token_id=concat_token_id, - ) - - if accelerator.is_main_process: - key = "val/loss" if model_name is None else f"val/{model_name}_loss" - if enable_print: - prefix = "" - if iter_num is not None: - prefix += f"iter {iter_num}" - if max_iters is not None: - prefix += f"/{max_iters}" - prefix += " - " - accelerator.print(f"{prefix}{key}: {val_loss:.4f}", show_delta=True) - if enable_wandb_log: - wandb.log({key: val_loss}, step=iter_num) - accelerator.wait_for_everyone() - - return val_loss class UnshardedLowMemorySparseTensor: @@ -325,37 +168,6 @@ def calculate_losses( return losses, None -def calc_entropy(logits: torch.Tensor) -> torch.Tensor: - """ - Returns per-token entropy given a logits tensor of shape [batch_size x seq_len x vocab_size]. - The output will have shape [batch_size x seq_len]. - """ - # Convert logits to log-probabilities - log_probs = F.log_softmax(logits, dim=-1) # shape: [B x T x V] - - # Compute probabilities from log-probabilities - probs = torch.exp(log_probs) # shape: [B x T x V] - - # Entropy calculation: sum over V of (- p * log p) - ent = -torch.sum(probs * log_probs, dim=-1) # shape: [B x T] - - return ent - - -def confidence_max_softmax(logits: torch.Tensor) -> torch.Tensor: - """ - Returns per-token max-softmax confidence given a logits tensor of shape [batch_size x seq_len x vocab_size]. - The output will have shape [batch_size x seq_len]. - """ - # Compute softmax probabilities - probs = F.softmax(logits, dim=-1) # shape: [B x T x V] - - # Take the maximum probability along the vocabulary dimension - max_confidence = torch.max(probs, dim=-1).values # shape: [B x T] - - return max_confidence - - def calculate_batch_outputs( hidden_states: torch.Tensor | None, target_hidden_states: torch.Tensor | None, @@ -380,8 +192,6 @@ def calculate_batch_outputs( batch_outputs = _calculate_ground_truth_based_scores(logits, targets) - # _DEBUG_calculate_per_token_entropy(batch_outputs, logits) - if (target_hidden_states is not None) or (target_logits is not None): batch_outputs.update( _calculate_teacher_similarity_scores( @@ -399,20 +209,6 @@ def calculate_batch_outputs( return batch_outputs -def _DEBUG_calculate_per_token_entropy(batch_outputs, logits, i_batch): - import os - - # calculate the per token entropy and per token top p - entropy = calc_entropy(logits).cpu() # .view(-1)#.tolist() - msftm = confidence_max_softmax(logits).cpu() # .view(-1)#.tolist() - teacher_dir = ".../meta-llama/Meta-Llama-3.1-70B-Instruct-new_rope/" - file_path = f"{teacher_dir}/validation/per_token_stats_{i_batch}.pth" - os.makedirs(os.path.dirname(file_path), exist_ok=True) - torch.save({"entropy": entropy, "max_softmax": msftm}, file_path) - batch_outputs["entropy"] = entropy - batch_outputs["max_softmax"] = msftm - - def _organize_outputs( outputs_per_batch: list[dict], ) -> tuple[dict[str, dict], list[torch.Tensor] | None]: @@ -473,28 +269,6 @@ def _calculate_ground_truth_based_scores( return scores -def _calculate_per_sample_kl_div_loss( - logits: torch.Tensor, - batch_target_probs: torch.Tensor | LowMemorySparseTensor, -) -> list[float]: - if isinstance(batch_target_probs, LowMemorySparseTensor): - logits = top_p_top_k(logits) - curr_target_probs = batch_target_probs.to_dense().to(logits.device) # .float() - per_sample_kl_div = [ - F.kl_div( - logits[i_sample].log_softmax(-1), - curr_target_probs[i_sample], - reduction="none", - log_target=False, - ) - .sum(-1) - .mean(-1) - .item() - for i_sample in range(logits.shape[0]) - ] - return per_sample_kl_div - - def cosine_embedding_loss( hidden_states: torch.Tensor, target_hidden_states: torch.Tensor, @@ -762,49 +536,6 @@ def tv_dist( DEFAULT_TOP_K = 1000 -def calculate_sparse_probs( - logits: torch.Tensor, - top_p: float | None = DEFAULT_TOP_P, - top_k: int | None = DEFAULT_TOP_K, - verbose: bool = False, -) -> LowMemorySparseTensor: - warped_logits = top_p_top_k(logits, top_p, top_k) - probs = warped_logits.softmax(-1) - sparse_probs = LowMemorySparseTensor(probs) - if True: # Always calculate these metrics (was: if verbose or True:) - probs_unfiltered = logits.softmax(-1) - num_active_per_token = (warped_logits > -1000).sum(-1).float() - prob_density = torch.tensor( - [ - probs_unfiltered[i, j, warped_logits[i, j] > -1000].sum(-1).float() - for j in range(probs_unfiltered.shape[1]) - for i in range(probs_unfiltered.shape[0]) - ] - ) - - print(f""" - Sparsity: - {num_active_per_token.mean().item()=} - {num_active_per_token.quantile(0.25).item()=} - {num_active_per_token.quantile(0.5).item()=} - {num_active_per_token.quantile(0.75).item()=} - {num_active_per_token.quantile(0.9).item()=} - {num_active_per_token.quantile(0.95).item()=} - {num_active_per_token.max().item()=} - - {probs_unfiltered.shape=} - {prob_density.shape=} - {prob_density.mean().item()=} - {prob_density.quantile(0.25).item()=} - {prob_density.quantile(0.5).item()=} - {prob_density.quantile(0.75).item()=} - {prob_density.quantile(0.9).item()=} - {prob_density.quantile(0.95).item()=} - {prob_density.max().item()=} - """) - return sparse_probs - - def top_p_top_k( logits: torch.Tensor, top_p: float | None = DEFAULT_TOP_P, diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 033b4aadbd..9b32d1ac46 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -20,6 +20,8 @@ import os import time from collections.abc import Callable +from contextlib import suppress +from datetime import timedelta from typing import Any import torch @@ -70,11 +72,23 @@ def rank(group=None) -> int: return 0 +def local_rank() -> int: + """Returns the local rank of the current process.""" + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + raise RuntimeError("LOCAL_RANK environment variable not found.") + + def is_master(group=None) -> bool: """Returns whether the current process is the master process.""" return rank(group=group) == 0 +def is_last_process(group=None) -> bool: + """Returns whether the current process is the last process.""" + return rank(group=group) == size(group=group) - 1 + + def _serialize(obj: Any) -> torch.Tensor: buffer = io.BytesIO() torch.save(obj, buffer) @@ -184,6 +198,21 @@ def wrapper(*args, **kwargs): return wrapper +def setup(timeout: timedelta | None = None): + """Sets up the distributed environment.""" + torch.cuda.set_device(local_rank()) + if not is_initialized(): + torch.distributed.init_process_group("cpu:gloo,cuda:nccl", timeout=timeout) + + +def cleanup(): + """Cleans up the distributed environment.""" + if is_initialized(): + with suppress(Exception): + barrier() + torch.distributed.destroy_process_group() + + class DistributedProcessGroup: """A convenient wrapper around torch.distributed.ProcessGroup objects.""" diff --git a/setup.py b/setup.py index 20a271fe15..e19935a88c 100644 --- a/setup.py +++ b/setup.py @@ -111,7 +111,6 @@ "omegaconf==2.3.0", "pandas", "typeguard", - "wandb~=0.17.5", ], } diff --git a/tests/gpu/torch/_compress/compress_test_utils.py b/tests/gpu/torch/_compress/compress_test_utils.py index 9df5f5bfcf..1da08602bf 100644 --- a/tests/gpu/torch/_compress/compress_test_utils.py +++ b/tests/gpu/torch/_compress/compress_test_utils.py @@ -21,14 +21,12 @@ from datasets import Dataset, DatasetDict from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers def setup_test_model_and_data( - project_root_path: Path, - tmp_path: Path, - rank: int, - runtime, + project_root_path: Path, tmp_path: Path, rank: int ) -> tuple[Path, Path, Path]: """ Setup the test model and data for the compress NAS search. @@ -37,7 +35,6 @@ def setup_test_model_and_data( project_root_path (Path): the root path of the project tmp_path (Path): the temporary path to use for the test rank (int): the rank of the process - runtime: the runtime to use for the test Returns: tuple[Path, Path, Path]: @@ -63,7 +60,7 @@ def setup_test_model_and_data( create_and_save_small_llama_model( llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer ) - runtime.wait_for_everyone() + dist.barrier() return ( puzzle_dir, diff --git a/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py b/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py index dbbcbacd47..913bc2116c 100644 --- a/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime import os +from datetime import timedelta from functools import partial from pathlib import Path @@ -23,14 +23,10 @@ from gpu.torch._compress.compress_test_utils import setup_test_model_and_data import modelopt.torch.nas as mtn +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.tools.runtime import NativeDdpRuntime -# -# See tests/gpu/torch/_compress/test_compress.py for instructions on how to run this test -# TODO: Remove those instructions once this test runs automatically on CI -# def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -42,51 +38,49 @@ def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): def _test_nas_convert_ffn_pruning_multiprocess_job( project_root_path: Path, tmp_path: Path, rank: int, size: int ): - with NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) as runtime: - # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, runtime - ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" - - # - # Run the mnt.convert() step - # - input_model = CompressModel() - mtn.convert( - input_model, - mode=[ - ( - "compress", - { - "puzzle_dir": str(puzzle_dir), - "input_model_path": str(llama_checkpoint_path), - "hydra_config_dir": str(hydra_config_dir), - "hydra_config_name": hydra_config_name, - "dataset_path": str(dataset_path), - }, - ) - ], - ) - - # - # Check assertions - # - if rank == 0: - # assertions for the score_pruning_activations step - rank = int(os.environ["RANK"]) - rank_filepath = ( - f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + dist.setup(timeout=timedelta(10)) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_name = "Llama-3_1-8B-ffn-pruning" + + # + # Run the mnt.convert() step + # + input_model = CompressModel() + mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, ) - assert (puzzle_dir / rank_filepath).is_file() + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() - # assertions for the pruning_ckpts step - assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() - runtime.wait_for_everyone() + dist.cleanup() print("PYTEST SUMMARY: test_nas_convert_ffn_pruning() test has finished successfully") @@ -102,53 +96,51 @@ def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): def _test_nas_convert_attn_pruning_multiprocess_job( project_root_path: Path, tmp_path: Path, rank: int, size: int ): - with NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) as runtime: - # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, runtime - ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B-attn-pruning" - - # - # Run the mnt.convert() step - # - input_model = CompressModel() - mtn.convert( - input_model, - mode=[ - ( - "compress", - { - "puzzle_dir": str(puzzle_dir), - "input_model_path": str(llama_checkpoint_path), - "hydra_config_dir": str(hydra_config_dir), - "hydra_config_name": hydra_config_name, - "dataset_path": str(dataset_path), - }, - ) - ], - ) - - # - # Check assertions - # - if rank == 0: - # assertions for the score_pruning_activations step - rank = int(os.environ["RANK"]) - rank_filepath = ( - f"pruning/pruning_scores/attn_independent_kv_head_contribution/" - f"100samples_diverse_mini/rank_{rank}.pth" + dist.setup(timeout=timedelta(10)) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_name = "Llama-3_1-8B-attn-pruning" + + # + # Run the mnt.convert() step + # + input_model = CompressModel() + mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, ) - assert (puzzle_dir / rank_filepath).is_file() + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/attn_independent_kv_head_contribution/" + f"100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() - # assertions for the pruning_ckpts step - assert (puzzle_dir / "ckpts/n_heads_in_group8").exists() - assert (puzzle_dir / "ckpts/n_heads_in_group16").exists() - assert (puzzle_dir / "ckpts/n_heads_in_group32").exists() + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/n_heads_in_group8").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group16").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group32").exists() - runtime.wait_for_everyone() + dist.cleanup() print("PYTEST SUMMARY: test_nas_convert_attn_pruning() test has finished successfully") diff --git a/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py b/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py index e8ea24ecee..1b4ed93c66 100644 --- a/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py @@ -13,11 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# -# See tests/gpu/torch/_compress/test_compress.py for instructions on how to run this test -# TODO: Remove those instructions once this test runs automatically on CI -# -import datetime +from datetime import timedelta from functools import partial from pathlib import Path @@ -26,8 +22,8 @@ from gpu.torch._compress.compress_test_utils import setup_test_model_and_data import modelopt.torch.nas as mtn +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.tools.runtime import NativeDdpRuntime def test_nas_search(project_root_path: Path, tmp_path: Path): @@ -41,72 +37,68 @@ def test_nas_search(project_root_path: Path, tmp_path: Path): def _test_nas_search_multiprocess_job( project_root_path: Path, tmp_path: Path, rank: int, size: int ): - with NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) as runtime: - # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, runtime - ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" - - # - # Run the mnt.convert() step - # - input_model = CompressModel() - converted_model = mtn.convert( - input_model, - mode=[ - ( - "compress", - { - "puzzle_dir": str(puzzle_dir), - "input_model_path": str(llama_checkpoint_path), - "hydra_config_dir": str(hydra_config_dir), - "hydra_config_name": hydra_config_name, - "dataset_path": str(dataset_path), - }, - ) - ], - ) + dist.setup(timeout=timedelta(10)) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_name = "Llama-3_1-8B-ffn-pruning" + + # + # Run the mnt.convert() step + # + input_model = CompressModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) - # - # Run the mnt.search() step - # - mtn.search( - converted_model, - constraints={}, # this is not used as the search space is defined in the hydra config - dummy_input=None, # Not used - config={}, # this is not used as the search space is defined in the hydra config - ) + # + # Run the mnt.search() step + # + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) - # - # Check assertions for mtn.search() step - # - if rank == 0: - # assertions for the build_library_and_stats step - assert (puzzle_dir / "replacement_library.json").is_file() - assert (puzzle_dir / "subblock_stats.json").is_file() - - # assertions for the scoring step - solution_0_filepath = ( - puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" - ) + # + # Check assertions for mtn.search() step + # + if rank == 0: + # assertions for the build_library_and_stats step + assert (puzzle_dir / "replacement_library.json").is_file() + assert (puzzle_dir / "subblock_stats.json").is_file() + + # assertions for the scoring step + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) - assert solution_0_filepath.exists() + assert solution_0_filepath.exists() - # assertions for the mip_and_realize_models step - solution_0_ckpt_config_path = ( - puzzle_dir - / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" - ) + # assertions for the mip_and_realize_models step + solution_0_ckpt_config_path = ( + puzzle_dir + / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" + ) - assert solution_0_ckpt_config_path.exists() - assert ( - puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json" - ).exists() + assert solution_0_ckpt_config_path.exists() + assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists() - runtime.wait_for_everyone() + dist.cleanup() print("PYTEST SUMMARY: test_nas_search() test has finished successfully") diff --git a/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml index 178edb50d8..192b82c75e 100644 --- a/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml +++ b/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml @@ -1,3 +1,5 @@ +model_dtype: torch.bfloat16 +autocast_dtype: torch.bfloat16 block_size: 8192 bos_rate: 0.5 data_column: conversation diff --git a/tests/gpu/torch/_compress/test_compress.py b/tests/gpu/torch/_compress/test_compress.py index e40756602a..997bb99719 100644 --- a/tests/gpu/torch/_compress/test_compress.py +++ b/tests/gpu/torch/_compress/test_compress.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime import os +from datetime import timedelta from functools import partial from pathlib import Path @@ -22,11 +22,11 @@ from _test_utils.torch.distributed.utils import spawn_multiprocess_job from gpu.torch._compress.compress_test_utils import setup_test_model_and_data +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress import compress from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.tools.runtime import NativeDdpRuntime # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) # using a one-click command. @@ -43,66 +43,60 @@ def test_compress(project_root_path: Path, tmp_path: Path): def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): - with NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) as runtime: - # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, runtime + dist.setup(timeout=timedelta(10)) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_name = "Llama-3_1-8B-ffn-pruning" + + # Convert the Llama model to DeciLM model. + if rank == 0: + convert_llama3_to_decilm( + input_dir=llama_checkpoint_path, + output_dir=puzzle_dir / "ckpts/teacher", ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" - - # Convert the Llama model to DeciLM model. - if rank == 0: - convert_llama3_to_decilm( - input_dir=llama_checkpoint_path, - output_dir=puzzle_dir / "ckpts/teacher", - ) - runtime.wait_for_everyone() - - # Compress the model using a one-click approach - compress.compress( - str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path), runtime + dist.barrier() + + # Compress the model using a one-click approach + compress.compress(str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path)) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step 1 + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" ) + assert (puzzle_dir / rank_filepath).is_file() - # - # Check assertions - # - if rank == 0: - # assertions for the score_pruning_activations step 1 - rank = int(os.environ["RANK"]) - rank_filepath = ( - f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" - ) - assert (puzzle_dir / rank_filepath).is_file() - - # assertions for the pruning_ckpts step 2 - assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() - # assertions for the build_library_and_stats step 4 + # assertions for the build_library_and_stats step 4 - assert (puzzle_dir / "replacement_library.json").is_file() - assert (puzzle_dir / "subblock_stats.json").is_file() + assert (puzzle_dir / "replacement_library.json").is_file() + assert (puzzle_dir / "subblock_stats.json").is_file() - # assertions for the scoring step 5 - solution_0_filepath = ( - puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" - ) + # assertions for the scoring step 5 + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) - assert solution_0_filepath.exists() + assert solution_0_filepath.exists() - # assertions for the mip_and_realize_models step 6 - solution_0_ckpt_config_path = ( - puzzle_dir - / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" - ) + # assertions for the mip_and_realize_models step 6 + solution_0_ckpt_config_path = ( + puzzle_dir + / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" + ) - assert solution_0_ckpt_config_path.exists() - assert ( - puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json" - ).exists() + assert solution_0_ckpt_config_path.exists() + assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists() - runtime.wait_for_everyone() + dist.cleanup() print("PYTEST SUMMARY: test_compress_model() test has finished successfully") From f7a0cb08e8ba458f1b0b759730856252757f9363 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 17 Dec 2025 08:50:29 +0100 Subject: [PATCH 26/62] Use shared activation hooks component in the puzzle algorithm (#687) ## What does this PR do? - Use shared activation hooks component in the puzzle algorithm. - Improve test_compress test by checking if activation scoring works correctly. ## Additional Information In the next step, torch.nas.plugins.megatron_hooks should be renamed and possibly moved to other place, e.g., into torch.pruning.activation_scoring. --------- Signed-off-by: Daniel Korzekwa --- .../activation_hooks/utils.py | 20 +- .../_compress/utils/checkpoint_manager.py | 12 +- .../nas/plugins/megatron_hooks/__init__.py | 3 +- .../plugins/megatron_hooks/base_hooks.py} | 505 ++++++++++++++---- ...oks_analysis.py => base_hooks_analysis.py} | 4 +- .../plugins/megatron_hooks/megatron_hooks.py | 461 +--------------- tests/gpu/torch/_compress/test_compress.py | 34 +- ...t_megatron_hooks.py => test_base_hooks.py} | 31 +- ...nalysis.py => test_base_hooks_analysis.py} | 60 +-- 9 files changed, 459 insertions(+), 671 deletions(-) rename modelopt/torch/{_compress/activation_scoring/activation_hooks/hooks.py => nas/plugins/megatron_hooks/base_hooks.py} (53%) rename modelopt/torch/nas/plugins/megatron_hooks/{megatron_hooks_analysis.py => base_hooks_analysis.py} (97%) rename tests/gpu/torch/nas/plugins/megatron_hooks/{test_megatron_hooks.py => test_base_hooks.py} (78%) rename tests/gpu/torch/nas/plugins/megatron_hooks/{test_megatron_hooks_analysis.py => test_base_hooks_analysis.py} (75%) diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py index 457e37d74e..931ac762f5 100644 --- a/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py @@ -19,28 +19,34 @@ import re -from modelopt.torch._compress.activation_scoring.activation_hooks import hooks from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( + ForwardHook, + IndependentChannelContributionHook, + IndependentKvHeadContributionHook, + IterativeChannelContributionHook, + LayerNormContributionHook, +) def register_activation_hooks( model: DeciLMForCausalLM, activation_hooks_kwargs: dict -) -> tuple[dict[str, hooks.ActivationsHook], hooks.ActivationsHook]: +) -> tuple[dict[str, ForwardHook], type[ForwardHook]]: hook_class_map = { "mlp.down_proj": { - "independent": hooks.IndependentChannelContributionHook, - "iterative": hooks.IterativeChannelContributionHook, + "independent": IndependentChannelContributionHook, + "iterative": IterativeChannelContributionHook, }, "self_attn.o_proj": { - "independent_kv_head_contribution": hooks.IndependentKvHeadContributionHook, + "independent_kv_head_contribution": IndependentKvHeadContributionHook, }, r"regex:experts\.\d+\.down_proj$": { # For MoE - "independent": hooks.IndependentChannelContributionHook, + "independent": IndependentChannelContributionHook, }, # TODO: maybe this is too generic, and we should have it specifically for # input_layernorm and post_attention_layernorm; now it might select qk_norms "layernorm": { - "layer_norm_contribution": hooks.LayerNormContributionHook, + "layer_norm_contribution": LayerNormContributionHook, }, } diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/_compress/utils/checkpoint_manager.py index 7a27334469..b43c37481d 100644 --- a/modelopt/torch/_compress/utils/checkpoint_manager.py +++ b/modelopt/torch/_compress/utils/checkpoint_manager.py @@ -193,11 +193,9 @@ def update_progress(self, batch_idx: int, total_batches: int): # All ranks save their hook states if self.activation_hooks is not None: try: - from modelopt.torch._compress.activation_scoring.activation_hooks.hooks import ( - ActivationsHook, - ) + from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook - ActivationsHook.save_hook_states(self.activation_hooks, self.checkpoint_dir) + ForwardHook.save_hook_states(self.activation_hooks, self.checkpoint_dir) except Exception as e: mprint(f"Warning: Failed to save hook states: {e}") @@ -249,11 +247,9 @@ def finalize(self): # All ranks save their final hook states if self.activation_hooks is not None: try: - from modelopt.torch._compress.activation_scoring.activation_hooks.hooks import ( - ActivationsHook, - ) + from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook - saved_path = ActivationsHook.save_hook_states( + saved_path = ForwardHook.save_hook_states( self.activation_hooks, self.checkpoint_dir ) mprint(f"Final hook states saved to {saved_path}") diff --git a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py index 1d19308edf..0ba4405183 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py @@ -14,5 +14,6 @@ # limitations under the License. """Forward hooks for estimating importance scores for pruning.""" +from .base_hooks import * +from .base_hooks_analysis import * from .megatron_hooks import * -from .megatron_hooks_analysis import * diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py similarity index 53% rename from modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py rename to modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index 510f691111..bfc9b9290b 100644 --- a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -12,11 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# mypy: ignore-errors - -"""Provides hooks for capturing the inputs and the outputs of pytorch modules that are used for -activation scoring for pruning. -""" +"""Forward hooks for activation-based importance estimation.""" import gc import json @@ -30,80 +26,120 @@ from torch import nn import modelopt.torch.utils.distributed as dist - -# BlockConfig used at runtime, not just type hints (lines 680, 790) -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import ( - DeciLMConfig, # noqa: TC001 -) -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMRMSNorm from modelopt.torch._compress.tools.logger import aprint from modelopt.torch._compress.tools.robust_json import json_dump +__all__ = [ + "ForwardHook", + "IndependentChannelContributionHook", + "IndependentKvHeadContributionHook", + "IterativeChannelContributionHook", + "L2NormHook", + "LayerNormContributionHook", +] + def clear_gpu_memory(clear: bool) -> None: + """Clear GPU memory cache if requested. + + Args: + clear: If True, runs garbage collection and empties CUDA cache. + """ if clear: gc.collect() torch.cuda.empty_cache() -class ActivationsHook(ABC): +class ForwardHook(ABC): + """Base class for PyTorch forward hooks. + + This follows the PyTorch forward hook API where the second + parameter is 'args' (a tuple of positional arguments passed to forward()). + + Usage: + hook = MyHook() + module.register_forward_hook(hook) + """ + @abstractmethod - def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: - """ - A hook to be registered in pytorch modules: torch.nn.Module.register_forward_hook() + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that is called after the module's forward pass. Args: - module (nn.Module): - args (tuple[torch.Tensor]): Input of the pytorch module - output (torch.Tensor): Output of the pytorch module + module: The module this hook is registered on + args: Tuple of positional arguments passed to module.forward() + output: The output from module.forward() + + Returns: + None (does not modify the output) """ ... @abstractmethod - def to_dict(self) -> dict[str, torch.Tensor]: ... + def accumulate(self) -> torch.Tensor: + """Return accumulated importance scores. - def save_state(self) -> dict: - """ - Save the internal state of the hook for checkpointing. + This method should be called after all forward passes to retrieve + the final importance scores for each channel/feature. Returns: - dict: State dictionary that can be used to restore the hook's state + Tensor of importance scores, one per channel/feature. + + Raises: + AssertionError: If no activations have been collected yet. """ - # Default implementation - hooks should override this if they have state to save - return {} + ... + + @abstractmethod + def state_dict(self) -> dict: + """Return the internal state for checkpointing. - def load_state(self, state_dict: dict) -> None: + Returns: + dict: State dictionary containing checkpoint data. + Can contain tensors, ints, lists, etc. """ - Load the internal state of the hook from a checkpoint. + ... + + @abstractmethod + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint. Args: - state_dict: State dictionary previously returned by save_state() + state_dict: State dictionary previously returned by state_dict() """ - # Default implementation - hooks should override this if they have state to load + ... def get_progress_info(self) -> dict: - """ - Get progress information for this hook (e.g., current iteration, samples processed). + """Get progress information for this hook. Returns: - dict: Progress information + dict: Progress information (e.g., current iteration, samples processed). + Default implementation returns empty dict. """ - # Default implementation - hooks can override to provide progress info return {} + @abstractmethod + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert hook results to dictionary format for saving. + + Returns: + dict: Dictionary containing result tensors (e.g., "score", "channels_importance_ascending"). + """ + ... + @classmethod def dump_activations_logs( - cls: type["ActivationsHook"], - activation_hooks: dict[str, "ActivationsHook"], + cls: type["ForwardHook"], + activation_hooks: dict[str, "ForwardHook"], activations_log_dir: Path | str, args: DictConfig, ) -> None: - """ - Default implementation for dumping final activation scores logs to disk. + """Default implementation for dumping final activation scores logs to disk. + This is called only at the end of scoring to save final results. """ - activations_log_dir = Path(activations_log_dir) activations_log_dir.mkdir(exist_ok=True, parents=True) rank = dist.rank() @@ -122,12 +158,12 @@ def dump_activations_logs( @classmethod def save_hook_states( - cls: type["ActivationsHook"], - activation_hooks: dict[str, "ActivationsHook"], + cls: type["ForwardHook"], + activation_hooks: dict[str, "ForwardHook"], activations_log_dir: Path | str, ) -> None: - """ - Save hook states for checkpointing (separate from final results). + """Save hook states for checkpointing (separate from final results). + This can be called periodically during scoring. Note: Synchronization should be handled at a higher level to avoid deadlocks. """ @@ -137,49 +173,179 @@ def save_hook_states( hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth" hook_states = { - module_name: hook.save_state() for module_name, hook in activation_hooks.items() + module_name: hook.state_dict() for module_name, hook in activation_hooks.items() } torch.save(hook_states, hook_states_path) - return hook_states_path +class L2NormHook(ForwardHook): + """Hook for accumulating activation statistics for importance estimation. + + Activations are computed as mean over seq_len and then squared and summed over batch_size. + In the accumulate() method we take the square root of the sum to get the L2 norm. + + This is the base version without tensor parallelism support. + For megatron with TP > 1, use MegatronL2NormHook instead. + + Args: + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__(self, max_size: int | None = None): + """Initialize the L2NormHook.""" + self.max_size = max_size + self._activations: torch.Tensor | None = None + + def _get_input_tensor(self, args: tuple[torch.Tensor, ...]) -> torch.Tensor: + """Get input tensor from args. Override in subclass for TP gathering.""" + return args[0].detach() + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Accumulate activation statistics from the forward pass. + + Args: + module: The module this hook is registered on. + args: Tuple of input tensors. args[0] expected shape: [seq_len, batch_size, hidden_size] + (Megatron sequence-first format). + output: Output tensor from the module's forward pass. + """ + input_tensor = self._get_input_tensor(args) + + if input_tensor.dim() == 2: + # For sparse experts, there is no batch dimension. + input_tensor = input_tensor[:, None, :] + + # Dont aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and input_tensor.shape[-1] != self.max_size: + return + + input_tensor = input_tensor.to(torch.float32) # use full precision to avoid overflow + activations = input_tensor.abs().mean(dim=0) # [batch_size, hidden_size] + activations = activations.pow(2).sum(dim=0) # [hidden_size] + + if self._activations is None: + self._activations = activations + else: + self._activations += activations + + def accumulate(self) -> torch.Tensor: + """Return the accumulated L2 norm of activations. + + Returns: + Tensor of accumulated scores, one per channel + + Raises: + AssertionError: If no activations have been collected yet + """ + assert self._activations is not None, "No activations collected for importance estimation." + # Convert squared sum to L2 norm + return self._activations.pow(0.5) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert to dict format for saving.""" + return {"score": self.accumulate().cpu()} + + def state_dict(self) -> dict: + """Return the state dictionary containing activations.""" + return {"activations": self._activations} + + def load_state_dict(self, state_dict: dict) -> None: + """Load activations from checkpoint.""" + self._activations = state_dict["activations"] + + +class IndependentChannelContributionHook(ForwardHook): + """Hook for channel importance estimation using weight norms and activation magnitudes. + + Computes channel importance as the product of: + - L2 norm of each column in the weight matrix (how much each input channel affects output) + - Mean absolute activation for each channel (how strongly each channel is activated) + + Args: + linear_layer: The linear projection layer to analyze. Must have a `weight` attribute + and either `in_features` (nn.Linear) or `input_size` (Megatron RowParallelLinear). + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__( + self, + linear_layer: nn.Module, + max_size: int | None = None, + ): + """Initialize the independent channel contribution hook.""" + self.max_size = max_size -class IndependentChannelContributionHook(ActivationsHook): - def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): weight_matrix = linear_layer.weight.float() self.weight_norm = torch.linalg.vector_norm(weight_matrix, dim=0) - num_channels = linear_layer.in_features + + # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) + if hasattr(linear_layer, "input_size"): + self.num_channels = linear_layer.input_size # Megatron-Core + else: + self.num_channels = linear_layer.in_features # PyTorch + self.agg_channel_activations = torch.zeros( - size=(num_channels,), dtype=torch.float32, device=weight_matrix.device + size=(self.num_channels,), + dtype=torch.float32, + device=weight_matrix.device, ) - def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: - """ - :param module: - :param args: tuple with one tensor entry (B,T,I) - :param output: B,T,E + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple + ) -> None: + """Accumulate mean absolute activations per channel. + + Args: + module: The module this hook is registered on. + args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] + (PyTorch batch-first format). + output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) + for parallel layers. """ activations = args[0] + + # Don't aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and activations.shape[-1] != self.max_size: + return + mean_abs_channel_activations = ( activations.abs().float().mean(dim=list(range(activations.ndim - 1))) ) - self.agg_channel_activations[:] += mean_abs_channel_activations # shape [I] + self.agg_channel_activations[:] += mean_abs_channel_activations # shape [input_channels] def to_dict(self) -> dict[str, torch.Tensor]: + """Convert results to dict with channel importance scores. + + Returns: + Dict with "score" (weight_norm * activations), "weight_norm", and + "agg_channel_activations". + """ return { "score": (self.weight_norm * self.agg_channel_activations).cpu(), "weight_norm": self.weight_norm.cpu(), "agg_channel_activations": self.agg_channel_activations.cpu(), } - def save_state(self) -> dict: + def accumulate(self) -> torch.Tensor: + """Return importance scores as a tensor. + + Returns: + Tensor of importance scores (weight_norm * activations), one per channel. + """ + return self.to_dict()["score"] + + def state_dict(self) -> dict: """Save the internal state for checkpointing.""" return { "agg_channel_activations": self.agg_channel_activations.cpu().clone(), "weight_norm": self.weight_norm.cpu().clone(), } - def load_state(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: """Load the internal state from a checkpoint.""" self.agg_channel_activations = state_dict["agg_channel_activations"].to( self.agg_channel_activations.device @@ -188,14 +354,14 @@ def load_state(self, state_dict: dict) -> None: # but we can verify it matches expected_weight_norm = state_dict["weight_norm"].to(self.weight_norm.device) if not torch.allclose(self.weight_norm, expected_weight_norm, rtol=1e-5): - print( - "Warning: weight_norm mismatch during state loading - model weights may have changed" + raise AssertionError( + "weight_norm mismatch during state loading - model weights may have changed" ) def get_pruning_schedule(num_channels, pruning_iters): - """ - Spending decreases monotonically when num_channels >= pruning_iters. + """Spending decreases monotonically when num_channels >= pruning_iters. + Intervals between spends increase monotonically when pruning_iters > num_channels. The budget is fully utilized, and there's spending in the last iteration. num_channels = 10, pruning_iters = 4 ==> [3, 3, 2, 2] @@ -223,16 +389,40 @@ def get_pruning_schedule(num_channels, pruning_iters): return schedule -class IterativeChannelContributionHook(ActivationsHook): - def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): - """TODO: Add docstring. +class IterativeChannelContributionHook(ForwardHook): + """Hook for iterative channel pruning based on contribution analysis. - Args: - linear_layer: The linear projection layer - activation_hooks_kwargs: The activation hooks kwargs - """ + Progressively identifies and removes the least important input channels of a linear layer + by measuring channel contribution as the L2 norm of output change when removed. + + Args: + linear_layer: The linear projection layer to analyze. Must have a `weight` attribute + and either `in_features` (nn.Linear) or `input_size` (Megatron RowParallelLinear). + activation_hooks_kwargs: Configuration dict with: + - validation_full_iters (int): Number of pruning iterations. + - clear_gpu_memory (bool, optional): Clear GPU memory during computation. + - calibration_method (str, optional): "scale_by_magnitude" or None. + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__( + self, + linear_layer: nn.Module, + activation_hooks_kwargs: dict, + max_size: int | None = None, + ): + """Initialize the iterative channel contribution hook.""" self.weight_matrix = linear_layer.weight - self.num_channels = linear_layer.in_features + + # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) + # TODO: Consider better design to handle RowParallelLinear and nn.Linear + if hasattr(linear_layer, "input_size"): + self.num_channels = linear_layer.input_size # Megatron-Core + else: + self.num_channels = linear_layer.in_features # PyTorch + + self.max_size = max_size self.pruning_iters = activation_hooks_kwargs["validation_full_iters"] self.clear_gpu_memory = activation_hooks_kwargs.get("clear_gpu_memory", False) self.curr_iter = 0 @@ -249,13 +439,31 @@ def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): self.calibration_method = activation_hooks_kwargs.get("calibration_method") self.epsilon = 1e-8 - def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: - """ - :param module: - :param args: tuple with one tensor entry (B,T,I) - :param output: B,T,E + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple + ) -> None: + """Compute channel contributions and prune channels according to schedule. + + Args: + module: The module this hook is registered on. + args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] + (PyTorch batch-first format). + output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) + for parallel layers. """ + # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear) + # TODO: Consider better design to handle RowParallelLinear and nn.Linear + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + activations = args[0] + + # Don't aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and activations.shape[-1] != self.max_size: + return + n_channels_to_prune = self.pruning_schedule[self.curr_iter] curr_activations = activations.clone() # Shape B,T,I @@ -263,9 +471,9 @@ def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.T output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E if self.calibration_method is None: - scaling_factor_per_token = torch.ones_like(output[..., 0]) # Shape B,T + scaling_factor_per_token = torch.ones_like(output_tensor[..., 0]) # Shape B,T elif self.calibration_method == "scale_by_magnitude": - output_norms = torch.linalg.vector_norm(output, dim=-1) # Shape B,T + output_norms = torch.linalg.vector_norm(output_tensor, dim=-1) # Shape B,T output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon) del output_curr_norms, output_norms @@ -274,7 +482,7 @@ def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.T del curr_activations clear_gpu_memory(clear=self.clear_gpu_memory) - s = scaling_factor_per_token.unsqueeze(-1) * output - output_curr # Shape: (B, T, E) + s = scaling_factor_per_token.unsqueeze(-1) * output_tensor - output_curr # Shape: (B, T, E) s_squared_per_token = torch.sum(s**2, dim=-1) # Shape: (B, T) b = s @ self.weight_matrix # Shape: (B, T, I) c = torch.sum(self.weight_matrix**2, dim=0) # Shape: (I) @@ -293,10 +501,11 @@ def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.T del contribution, contribution_squared clear_gpu_memory(clear=self.clear_gpu_memory) - if n_channels_to_prune == 0: - self.agg_cont_per_channel += mean_cont_per_channel - else: - _, worst_indices = torch.topk(mean_cont_per_channel, n_channels_to_prune, largest=False) + self.agg_cont_per_channel += mean_cont_per_channel + if n_channels_to_prune > 0: + _, worst_indices = torch.topk( + self.agg_cont_per_channel, n_channels_to_prune, largest=False + ) worst_indices_list = worst_indices.tolist() assert not set(self.pruned_channels).intersection(set(worst_indices_list)) self.pruned_channels.extend(worst_indices_list) @@ -304,6 +513,12 @@ def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.T self.curr_iter += 1 def to_dict(self) -> dict[str, torch.Tensor]: + """Convert pruning results to dict with channel importance rankings. + + Returns: + Dict with "score" (importance rank per channel) and + "channels_importance_ascending" (channel indices in ascending importance). + """ assert self.num_channels == len(self.pruned_channels) channels_importance_ascending = torch.tensor(self.pruned_channels, dtype=torch.long) score = torch.empty(self.num_channels, dtype=torch.long) @@ -314,7 +529,15 @@ def to_dict(self) -> dict[str, torch.Tensor]: "channels_importance_ascending": channels_importance_ascending.cpu(), } - def save_state(self) -> dict: + def accumulate(self) -> torch.Tensor: + """Return importance scores as a tensor. + + Returns: + Tensor of importance scores, one per channel. Lower scores indicate less important channels. + """ + return self.to_dict()["score"] + + def state_dict(self) -> dict: """Save the internal state for checkpointing.""" return { "curr_iter": self.curr_iter, @@ -327,7 +550,7 @@ def save_state(self) -> dict: "epsilon": self.epsilon, } - def load_state(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: """Load the internal state from a checkpoint.""" self.curr_iter = state_dict["curr_iter"] self.pruned_channels = state_dict["pruned_channels"].copy() @@ -338,7 +561,11 @@ def load_state(self, state_dict: dict) -> None: assert self.pruning_schedule == state_dict["pruning_schedule"], "Pruning schedule mismatch" def get_progress_info(self) -> dict: - """Get progress information.""" + """Get progress information for this hook. + + Returns: + dict: Progress information including iteration count and pruned channels. + """ progress = self.curr_iter / self.pruning_iters if self.pruning_iters > 0 else 0.0 return { "curr_iter": self.curr_iter, @@ -349,16 +576,24 @@ def get_progress_info(self) -> dict: } -class IndependentKvHeadContributionHook(ActivationsHook): - def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): - """TODO: Add docstring. +class IndependentKvHeadContributionHook(ForwardHook): + """Hook for estimating KV head importance based on contribution analysis. - Args: - linear_layer: The linear projection layer - activation_hooks_kwargs: The activation hooks kwargs - """ - model_config: DeciLMConfig = activation_hooks_kwargs["model"].config - block_config: BlockConfig = activation_hooks_kwargs["block_config"] + Measures the contribution of each KV head group to the output projection + by computing L2 norms of per-head outputs. + + Args: + linear_layer: The output projection layer (o_proj). + activation_hooks_kwargs: Configuration dict with: + - model: The model instance (to get config). + - block_config: Block configuration with attention settings. + - optimize_for (str, optional): "latency" or "memory". Defaults to "memory". + """ + + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + """Initialize the KV head contribution hook.""" + model_config = activation_hooks_kwargs["model"].config + block_config = activation_hooks_kwargs["block_config"] self.optimize_for = activation_hooks_kwargs.get("optimize_for", "memory") assert self.optimize_for in ["latency", "memory"] @@ -382,11 +617,7 @@ def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): # weight_grouped.shape: (kv_heads, hidden_dim, head_dim * n_heads_in_group) def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: - """ - :param module: The linear projection layer - :param args: tuple containing attention output tensor (B, T, num_q_heads * head_dim) - :param output: The projected output (B, T, hidden_dim) - """ + """Compute KV head contributions from the forward pass.""" attn_out = args[0] # Shape: (B, T, num_q_heads * head_dim) batch_size, seq_len, _ = attn_out.shape @@ -423,14 +654,44 @@ def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.T # Accumulate contributions self.agg_kv_head_contributions += contrib_per_kv_head + def accumulate(self) -> torch.Tensor: + """Return accumulated KV head importance scores. + + Returns: + Tensor of importance scores, one per KV head. + """ + return self.agg_kv_head_contributions + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert to dict format for saving. + + Returns: + Dict with "score" tensor containing KV head importance scores. + """ return { "score": self.agg_kv_head_contributions.cpu(), } + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + raise NotImplementedError("Saving state dict is not supported for this hook.") + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + raise NotImplementedError("Loading state dict is not supported for this hook.") + + +class LayerNormContributionHook(ForwardHook): + """Hook for estimating channel importance based on layer normalization activations. -class LayerNormContributionHook(ActivationsHook): - def __init__(self, layernorm_layer: DeciLMRMSNorm, activation_hooks_kwargs: dict): + Aggregates mean absolute activation values per channel for a layer normalization layer. + + Args: + layernorm_layer: The layer normalization layer. + activation_hooks_kwargs: The activation hooks kwargs (not used). + """ + + def __init__(self, layernorm_layer: nn.Module, activation_hooks_kwargs: dict): """Aggregates mean absolute activation values per channel for a layer normalization layer. Args: @@ -444,22 +705,41 @@ def __init__(self, layernorm_layer: DeciLMRMSNorm, activation_hooks_kwargs: dict ) def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """Accumulate activation statistics from the forward pass.""" self.agg_embedding_activations += ( output.abs().float().mean(dim=list(range(output.ndim - 1))) ) + def accumulate(self) -> torch.Tensor: + """Return accumulated channel importance scores.""" + return self.agg_embedding_activations + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert to dict format for saving.""" + return { + "score": self.agg_embedding_activations.cpu(), + "channels_importance_ascending": self.agg_embedding_activations.sort()[1].cpu(), + } + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + raise NotImplementedError("Saving state dict is not supported for this hook.") + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + raise NotImplementedError("Loading state dict is not supported for this hook.") + @classmethod def dump_activations_logs( cls: type["LayerNormContributionHook"], - activation_hooks: dict[str, "ActivationsHook"], + activation_hooks: dict[str, "ForwardHook"], activations_log_dir: Path | str, args: DictConfig, ) -> None: - """ - At the end of the default implementation of dumping activation scores to disc, - save aggregated channel importance results. - """ + """At the end of the default implementation of dumping activation scores to disc. + Save aggregated channel importance results. + """ super().dump_activations_logs(activation_hooks, activations_log_dir, args) rank = dist.rank() @@ -472,14 +752,11 @@ def dump_activations_logs( @staticmethod def _save_channel_importance_results( - activation_hooks: dict[str, ActivationsHook], - activations_log_dir: Path, + activation_hooks: dict[str, "ForwardHook"], + activations_log_dir: Path | str, args: DictConfig, ) -> None: - """ - Save channel importance results from activation hooks. - """ - + """Save channel importance results from activation hooks.""" # Find all activation files (for multi-rank scenarios) activations_log_dir = Path(activations_log_dir) activation_files = list(activations_log_dir.glob("rank_*.pth")) @@ -545,9 +822,3 @@ def _save_channel_importance_results( aprint(f"Score range: {avg_scores.min():.4f} to {avg_scores.max():.4f}") aprint(f"Score mean: {avg_scores.mean():.4f}") aprint(f"Score std: {avg_scores.std():.4f}") - - def to_dict(self) -> dict[str, torch.Tensor]: - return { - "score": self.agg_embedding_activations.cpu(), - "channels_importance_ascending": self.agg_embedding_activations.sort()[1].cpu(), - } diff --git a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks_analysis.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py similarity index 97% rename from modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks_analysis.py rename to modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py index caf5eed898..dc338a7cfa 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks_analysis.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Analysis tools for evaluating importance scores from megatron hooks.""" +"""Analysis tools for evaluating importance scores from hooks.""" import torch import torch.nn.functional as F @@ -36,7 +36,7 @@ def evaluate_importance_scores( linear_layer: The linear layer to analyze with shape (out_features, in_features). For example: nn.Linear(in_features=1024, out_features=4096) activations_batches: List of input activation tensors. - Each tensor has shape [seq_len, batch_size, in_features] (Megatron format). + Each tensor has shape [seq_len, batch_size, in_features]. The last dimension must match linear_layer.in_features. Example: List of [16, 8, 1024] tensors importance_scores: Importance score for each input channel (feature). diff --git a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py index 3bb1493950..d792ff8941 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py @@ -12,466 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Forward hooks for activation-based importance estimation (megatron NAS plugin).""" - -import gc -from abc import ABC, abstractmethod +"""Megatron-specific hooks with tensor parallelism support.""" import torch -import torch.nn.functional as F from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region -from megatron.core.tensor_parallel.layers import RowParallelLinear -from torch import nn - -__all__ = [ - "IndependentChannelContributionHook", - "IterativeChannelContributionHook", - "MegatronL2NormHook", -] - - -def clear_gpu_memory(clear: bool) -> None: - """Clear GPU memory cache if requested. - - Args: - clear: If True, runs garbage collection and empties CUDA cache. - """ - if clear: - gc.collect() - torch.cuda.empty_cache() - - -class ForwardHook(ABC): - """Base class for PyTorch forward hooks. - - This follows the PyTorch forward hook API where the second - parameter is 'args' (a tuple of positional arguments passed to forward()). - - Usage: - hook = MyHook() - module.register_forward_hook(hook) - """ - - @abstractmethod - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor - ) -> None: - """Forward hook that is called after the module's forward pass. - - Args: - module: The module this hook is registered on - args: Tuple of positional arguments passed to module.forward() - output: The output from module.forward() - - Returns: - None (does not modify the output) - """ - ... - - @abstractmethod - def accumulate(self) -> torch.Tensor: - """Return accumulated importance scores. - - This method should be called after all forward passes to retrieve - the final importance scores for each channel/feature. - - Returns: - Tensor of importance scores, one per channel/feature. - - Raises: - AssertionError: If no activations have been collected yet. - """ - ... - - @abstractmethod - def state_dict(self) -> dict: - """Return the internal state for checkpointing. - - Returns: - dict: State dictionary containing checkpoint data. - Can contain tensors, ints, lists, etc. - """ - ... - - @abstractmethod - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint. - Args: - state_dict: State dictionary previously returned by state_dict() - """ - ... +from .base_hooks import L2NormHook +__all__ = ["MegatronL2NormHook"] -class MegatronL2NormHook(ForwardHook): - """Hook for accumulating activation statistics for importance estimation. - Activations are computed as mean over seq_len and then squared and summed over batch_size. - In the accumulate() method we take the square root of the sum to get the L2 norm. +class MegatronL2NormHook(L2NormHook): + """L2NormHook with tensor parallelism support for Megatron models. - Args: - max_size: Optional maximum expected size to validate against (skips if mismatch). - Useful for skipping non-max subnets during profiling. + Extends L2NormHook to gather activations across all tensor parallel regions + before computing importance scores. """ - def __init__(self, max_size: int | None = None): - """Initialize the L2NormHook.""" - self.max_size = max_size - self._activations: torch.Tensor | None = None - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor - ) -> None: - """Accumulate activation statistics from the forward pass. - - Args: - module: The module this hook is registered on. - args: Tuple of input tensors. args[0] expected shape: [seq_len, batch_size, hidden_size] - (Megatron sequence-first format). - output: Output tensor from the module's forward pass. - """ + def _get_input_tensor(self, args: tuple[torch.Tensor, ...]) -> torch.Tensor: + """Gather input tensor from all TP regions.""" # Gather input [seq_len, batch_size, hidden_size] over all TP regions # NOTE: This is not used at the moment since we restrict to TP=1 - input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach() - - if input_tensor.dim() == 2: - # For sparse experts, there is no batch dimension. - input_tensor = input_tensor[:, None, :] - - # Dont aggregate activations from non-max subnets (e.g. from profiling) - if self.max_size is not None and input_tensor.shape[-1] != self.max_size: - return - - input_tensor = input_tensor.to(torch.float32) # use full precision to avoid overflow - activations = input_tensor.abs().mean(dim=0) # [batch_size, hidden_size] - activations = activations.pow(2).sum(dim=0) # [hidden_size] - - if self._activations is None: - self._activations = activations - else: - self._activations += activations - - def accumulate(self) -> torch.Tensor: - """Return the accumulated L2 norm of activations. - - Returns: - Tensor of accumulated scores, one per channel - - Raises: - AssertionError: If no activations have been collected yet - """ - assert self._activations is not None, "No activations collected for importance estimation." - # Convert squared sum to L2 norm - return self._activations.pow(0.5) - - def state_dict(self) -> dict: - """Return the state dictionary containing activations.""" - return {"activations": self._activations} - - def load_state_dict(self, state_dict: dict) -> None: - """Load activations from checkpoint.""" - self._activations = state_dict["activations"] - - -class IndependentChannelContributionHook(ForwardHook): - """Hook for channel importance estimation using weight norms and activation magnitudes. - - Computes channel importance as the product of: - - L2 norm of each column in the weight matrix (how much each input channel affects output) - - Mean absolute activation for each channel (how strongly each channel is activated) - - Args: - linear_layer: The linear projection layer to analyze. Can be either nn.Linear or - RowParallelLinear from megatron.core.tensor_parallel.layers. - max_size: Optional maximum expected size to validate against (skips if mismatch). - Useful for skipping non-max subnets during profiling. - """ - - def __init__( - self, - linear_layer: nn.Linear | RowParallelLinear, - max_size: int | None = None, - ): - """Initialize the independent channel contribution hook.""" - self.max_size = max_size - - weight_matrix = linear_layer.weight.float() - self.weight_norm = torch.linalg.vector_norm(weight_matrix, dim=0) - - # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) - if hasattr(linear_layer, "input_size"): - self.num_channels = linear_layer.input_size # Megatron-Core - else: - self.num_channels = linear_layer.in_features # PyTorch - - self.agg_channel_activations = torch.zeros( - size=(self.num_channels,), - dtype=torch.float32, - device=weight_matrix.device, - ) - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple - ) -> None: - """Accumulate mean absolute activations per channel. - - Args: - module: The module this hook is registered on. - args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] - (PyTorch batch-first format). - output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) - for parallel layers. - """ - activations = args[0] - - # Don't aggregate activations from non-max subnets (e.g. from profiling) - if self.max_size is not None and activations.shape[-1] != self.max_size: - return - - mean_abs_channel_activations = ( - activations.abs().float().mean(dim=list(range(activations.ndim - 1))) - ) - self.agg_channel_activations[:] += mean_abs_channel_activations # shape [input_channels] - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert results to dict with channel importance scores. - - Returns: - Dict with "score" (weight_norm * activations), "weight_norm", and - "agg_channel_activations". - """ - return { - "score": (self.weight_norm * self.agg_channel_activations).cpu(), - "weight_norm": self.weight_norm.cpu(), - "agg_channel_activations": self.agg_channel_activations.cpu(), - } - - def accumulate(self) -> torch.Tensor: - """Return importance scores as a tensor. - - Returns: - Tensor of importance scores (weight_norm * activations), one per channel. - """ - return self.to_dict()["score"] - - def state_dict(self) -> dict: - """Save the internal state for checkpointing.""" - return { - "agg_channel_activations": self.agg_channel_activations.cpu().clone(), - "weight_norm": self.weight_norm.cpu().clone(), - } - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - self.agg_channel_activations = state_dict["agg_channel_activations"].to( - self.agg_channel_activations.device - ) - # weight_norm should be the same as it's derived from the model weights - # but we can verify it matches - expected_weight_norm = state_dict["weight_norm"].to(self.weight_norm.device) - if not torch.allclose(self.weight_norm, expected_weight_norm, rtol=1e-5): - raise AssertionError( - "weight_norm mismatch during state loading - model weights may have changed" - ) - - -def get_pruning_schedule(num_channels, pruning_iters): - """Spending decreases monotonically when num_channels >= pruning_iters. - - Intervals between spends increase monotonically when pruning_iters > num_channels. - The budget is fully utilized, and there's spending in the last iteration. - num_channels = 10, pruning_iters = 4 ==> [3, 3, 2, 2] - num_channels = 4, pruning_iters = 10 ==> [0, 1, 0, 1, 0, 0, 1, 0, 0, 1] - """ - if num_channels >= pruning_iters: - # Case when budget is greater than or equal to iterations - q = num_channels // pruning_iters # Base spend per iteration - r = num_channels % pruning_iters # Remainder to distribute - - schedule = [] - for i in range(pruning_iters): - if i < r: - # Assign higher spend to earlier iterations - schedule.append(q + 1) - else: - schedule.append(q) - else: - # Case when iterations are greater than budget - schedule = [0] * pruning_iters - for i in range(1, num_channels + 1): - # Distribute spends at positions where intervals increase monotonically - pos = ((i * pruning_iters) // num_channels) - 1 - schedule[pos] = 1 - return schedule - - -class IterativeChannelContributionHook(ForwardHook): - """Hook for iterative channel pruning based on contribution analysis. - - Progressively identifies and removes the least important input channels of a linear layer - by measuring channel contribution as the L2 norm of output change when removed. - - Args: - linear_layer: The linear projection layer to analyze. Can be either nn.Linear or - RowParallelLinear from megatron.core.tensor_parallel.layers. - activation_hooks_kwargs: Configuration dict with: - - validation_full_iters (int): Number of pruning iterations. - - clear_gpu_memory (bool, optional): Clear GPU memory during computation. - - calibration_method (str, optional): "scale_by_magnitude" or None. - max_size: Optional maximum expected size to validate against (skips if mismatch). - Useful for skipping non-max subnets during profiling. - """ - - def __init__( - self, - linear_layer: nn.Linear | RowParallelLinear, - activation_hooks_kwargs: dict, - max_size: int | None = None, - ): - """Initialize the iterative channel contribution hook.""" - self.weight_matrix = linear_layer.weight - - # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) - # TODO: Consider better design to handle RowParallelLinear and nn.Linear - if hasattr(linear_layer, "input_size"): - self.num_channels = linear_layer.input_size # Megatron-Core - else: - self.num_channels = linear_layer.in_features # PyTorch - - self.max_size = max_size - self.pruning_iters = activation_hooks_kwargs["validation_full_iters"] - self.clear_gpu_memory = activation_hooks_kwargs.get("clear_gpu_memory", False) - self.curr_iter = 0 - self.pruning_schedule = get_pruning_schedule( - num_channels=self.num_channels, pruning_iters=self.pruning_iters - ) - - self.agg_cont_per_channel = torch.zeros( - size=(self.num_channels,), - dtype=torch.float32, - device=self.weight_matrix.device, - ) - self.pruned_channels = [] - self.calibration_method = activation_hooks_kwargs.get("calibration_method") - self.epsilon = 1e-8 - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple - ) -> None: - """Compute channel contributions and prune channels according to schedule. - - Args: - module: The module this hook is registered on. - args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] - (PyTorch batch-first format). - output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) - for parallel layers. - """ - # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear) - # TODO: Consider better design to handle RowParallelLinear and nn.Linear - if isinstance(output, tuple): - output_tensor = output[0] - else: - output_tensor = output - - activations = args[0] - - # Don't aggregate activations from non-max subnets (e.g. from profiling) - if self.max_size is not None and activations.shape[-1] != self.max_size: - return - - n_channels_to_prune = self.pruning_schedule[self.curr_iter] - - curr_activations = activations.clone() # Shape B,T,I - curr_activations[..., self.pruned_channels] = 0 - output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E - - if self.calibration_method is None: - scaling_factor_per_token = torch.ones_like(output_tensor[..., 0]) # Shape B,T - elif self.calibration_method == "scale_by_magnitude": - output_norms = torch.linalg.vector_norm(output_tensor, dim=-1) # Shape B,T - output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T - scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon) - del output_curr_norms, output_norms - else: - raise NotImplementedError - del curr_activations - clear_gpu_memory(clear=self.clear_gpu_memory) - - s = scaling_factor_per_token.unsqueeze(-1) * output_tensor - output_curr # Shape: (B, T, E) - s_squared_per_token = torch.sum(s**2, dim=-1) # Shape: (B, T) - b = s @ self.weight_matrix # Shape: (B, T, I) - c = torch.sum(self.weight_matrix**2, dim=0) # Shape: (I) - del s, output_curr - clear_gpu_memory(clear=self.clear_gpu_memory) - - contribution_squared = ( - s_squared_per_token.unsqueeze(2) + 2 * activations * b + (activations**2) * c - ) # Shape: (B, T, I) - del s_squared_per_token, b, c, activations - clear_gpu_memory(clear=self.clear_gpu_memory) - - contribution = torch.sqrt(contribution_squared + self.epsilon) # Shape: (B, T, I) - mean_cont_per_channel = torch.mean(contribution, dim=(0, 1)) # Shape: (I) - mean_cont_per_channel[self.pruned_channels] = torch.inf - del contribution, contribution_squared - clear_gpu_memory(clear=self.clear_gpu_memory) - - self.agg_cont_per_channel += mean_cont_per_channel - if n_channels_to_prune > 0: - _, worst_indices = torch.topk( - self.agg_cont_per_channel, n_channels_to_prune, largest=False - ) - worst_indices_list = worst_indices.tolist() - assert not set(self.pruned_channels).intersection(set(worst_indices_list)) - self.pruned_channels.extend(worst_indices_list) - self.agg_cont_per_channel.zero_() - self.curr_iter += 1 - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert pruning results to dict with channel importance rankings. - - Returns: - Dict with "score" (importance rank per channel) and - "channels_importance_ascending" (channel indices in ascending importance). - """ - assert self.num_channels == len(self.pruned_channels) - channels_importance_ascending = torch.tensor(self.pruned_channels, dtype=torch.long) - score = torch.empty(self.num_channels, dtype=torch.long) - score[channels_importance_ascending] = torch.arange(self.num_channels, dtype=torch.long) - - return { - "score": score.cpu(), - "channels_importance_ascending": channels_importance_ascending.cpu(), - } - - def accumulate(self) -> torch.Tensor: - """Return importance scores as a tensor. - - Returns: - Tensor of importance scores, one per channel. Lower scores indicate less important channels. - """ - return self.to_dict()["score"] - - def state_dict(self) -> dict: - """Save the internal state for checkpointing.""" - return { - "curr_iter": self.curr_iter, - "pruned_channels": self.pruned_channels.copy(), - "agg_cont_per_channel": self.agg_cont_per_channel.cpu().clone(), - "num_channels": self.num_channels, - "pruning_iters": self.pruning_iters, - "pruning_schedule": self.pruning_schedule.copy(), - "calibration_method": self.calibration_method, - "epsilon": self.epsilon, - } - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - self.curr_iter = state_dict["curr_iter"] - self.pruned_channels = state_dict["pruned_channels"].copy() - self.agg_cont_per_channel = state_dict["agg_cont_per_channel"].to(self.weight_matrix.device) - # Verify other parameters match - assert self.num_channels == state_dict["num_channels"], "Channel count mismatch" - assert self.pruning_iters == state_dict["pruning_iters"], "Iteration count mismatch" - assert self.pruning_schedule == state_dict["pruning_schedule"], "Pruning schedule mismatch" + return gather_from_tensor_model_parallel_region(args[0]).detach() diff --git a/tests/gpu/torch/_compress/test_compress.py b/tests/gpu/torch/_compress/test_compress.py index 997bb99719..24b8b8b2ec 100644 --- a/tests/gpu/torch/_compress/test_compress.py +++ b/tests/gpu/torch/_compress/test_compress.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from datetime import timedelta from functools import partial from pathlib import Path @@ -67,11 +66,7 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran # if rank == 0: # assertions for the score_pruning_activations step 1 - rank = int(os.environ["RANK"]) - rank_filepath = ( - f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" - ) - assert (puzzle_dir / rank_filepath).is_file() + _assert_score_pruning_activations(puzzle_dir) # assertions for the pruning_ckpts step 2 assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() @@ -99,4 +94,29 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran dist.cleanup() - print("PYTEST SUMMARY: test_compress_model() test has finished successfully") + print( + "PYTEST SUMMARY: test_compress_model() test has finished successfully. Puzzle directory: ", + puzzle_dir, + ) + + +def _assert_score_pruning_activations(puzzle_dir: Path): + """Assertions for the score_pruning_activations step 1.""" + rank = dist.rank() + rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + assert (puzzle_dir / rank_filepath).is_file() + + pruning_scores = torch.load(puzzle_dir / rank_filepath) + + layer_names = list(pruning_scores.keys()) + assert len(layer_names) == 2 + + # Check specific values for layer 0 + layer_0 = pruning_scores[layer_names[0]] + assert layer_0["score"][0].item() == 371 + assert layer_0["channels_importance_ascending"][0].item() == 140 + + # Check specific values for layer 1 + layer_1 = pruning_scores[layer_names[1]] + assert layer_1["score"][0].item() == 269 + assert layer_1["channels_importance_ascending"][0].item() == 366 diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py similarity index 78% rename from tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks.py rename to tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py index f94f2e85f4..aa73a3be19 100644 --- a/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks.py +++ b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py @@ -13,21 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for megatron hooks.""" +"""Unit tests for base hooks.""" import torch import torch.nn as nn -from _test_utils.import_helper import skip_if_no_megatron -skip_if_no_megatron() - -from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from megatron.core.parallel_state import initialize_model_parallel - -from modelopt.torch.nas.plugins.megatron_hooks import ( - IterativeChannelContributionHook, - MegatronL2NormHook, -) +from modelopt.torch.nas.plugins.megatron_hooks import IterativeChannelContributionHook, L2NormHook def _test_iterative_channel_contribution_hook_with_shape(dim1: int, dim2: int): @@ -81,15 +72,12 @@ def test_iterative_channel_contribution_hook_bsi(): _test_iterative_channel_contribution_hook_with_shape(dim1=8, dim2=32) -def _test_l2_norm_hook(rank, size): - """Internal test function that runs in spawned process with distributed setup.""" - # Initialize Megatron parallel state (distributed is already initialized by spawn_multiprocess_job) - initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) - +def test_l2_norm_hook(): + """Test L2NormHook returns correct scores after accumulating activations.""" torch.manual_seed(42) linear_layer = nn.Linear(in_features=6, out_features=4, bias=False) - hook = MegatronL2NormHook(max_size=None) + hook = L2NormHook(max_size=None) linear_layer.register_forward_hook(hook) num_iterations = 3 @@ -110,12 +98,3 @@ def _test_l2_norm_hook(rank, size): assert torch.allclose(scores, expected_scores, atol=1e-4), ( f"Expected scores {expected_scores}, got {scores}" ) - - -def test_l2_norm_hook(): - """Test MegatronL2NormHook returns correct scores after accumulating activations.""" - spawn_multiprocess_job( - size=1, - job=_test_l2_norm_hook, - backend="gloo", - ) diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks_analysis.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py similarity index 75% rename from tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks_analysis.py rename to tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py index 4f075c9dd2..954c6e11c7 100644 --- a/tests/gpu/torch/nas/plugins/megatron_hooks/test_megatron_hooks_analysis.py +++ b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py @@ -13,22 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for megatron hooks analysis tools.""" +"""Unit tests for base hooks analysis tools.""" import pytest import torch import torch.nn as nn -from _test_utils.import_helper import skip_if_no_megatron - -skip_if_no_megatron() - -from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from megatron.core.parallel_state import initialize_model_parallel from modelopt.torch.nas.plugins.megatron_hooks import ( IndependentChannelContributionHook, IterativeChannelContributionHook, - MegatronL2NormHook, + L2NormHook, evaluate_importance_scores, ) @@ -54,16 +48,13 @@ def test_evaluate_importance_scores_basic(): assert metrics["cosine_similarity"] == pytest.approx(0.77117118, rel=1e-5) -def _test_evaluate_importance_scores_with_l2_norm_hook(rank, size): - """Test evaluate_importance_scores with MegatronL2NormHook.""" - # Initialize Megatron parallel state - initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) - +def test_evaluate_importance_scores_with_l2_norm_hook(): + """Test evaluate_importance_scores with L2NormHook.""" torch.manual_seed(42) # Create layer and hook layer = nn.Linear(in_features=50, out_features=30, bias=False) - hook = MegatronL2NormHook(max_size=None) + hook = L2NormHook(max_size=None) # Run evaluation metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) @@ -76,11 +67,8 @@ def _test_evaluate_importance_scores_with_l2_norm_hook(rank, size): assert metrics["cosine_similarity"] == pytest.approx(0.7814186, rel=1e-5) -def _test_evaluate_importance_scores_with_iterative_channel_contribution_hook(rank, size): +def test_evaluate_importance_scores_with_iterative_channel_contribution_hook(): """Test evaluate_importance_scores with IterativeChannelContributionHook.""" - # Initialize Megatron parallel state - initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) - torch.manual_seed(42) # Create layer and hook @@ -103,11 +91,8 @@ def _test_evaluate_importance_scores_with_iterative_channel_contribution_hook(ra assert metrics["cosine_similarity"] == pytest.approx(0.8110392, rel=1e-5) -def _test_evaluate_importance_scores_with_independent_channel_contribution_hook(rank, size): +def test_evaluate_importance_scores_with_independent_channel_contribution_hook(): """Test evaluate_importance_scores with IndependentChannelContributionHook.""" - # Initialize Megatron parallel state - initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) - torch.manual_seed(42) # Create layer and hook @@ -125,33 +110,6 @@ def _test_evaluate_importance_scores_with_independent_channel_contribution_hook( assert metrics["cosine_similarity"] == pytest.approx(0.8116209, rel=1e-5) -def test_evaluate_importance_scores_with_l2_norm_hook(): - """Test evaluate_importance_scores using MegatronL2NormHook.""" - spawn_multiprocess_job( - size=1, - job=_test_evaluate_importance_scores_with_l2_norm_hook, - backend="gloo", - ) - - -def test_evaluate_importance_scores_with_iterative_channel_contribution_hook(): - """Test evaluate_importance_scores using IterativeChannelContributionHook.""" - spawn_multiprocess_job( - size=1, - job=_test_evaluate_importance_scores_with_iterative_channel_contribution_hook, - backend="gloo", - ) - - -def test_evaluate_importance_scores_with_independent_channel_contribution_hook(): - """Test evaluate_importance_scores using IndependentChannelContributionHook.""" - spawn_multiprocess_job( - size=1, - job=_test_evaluate_importance_scores_with_independent_channel_contribution_hook, - backend="gloo", - ) - - def _run_hook_and_evaluate( layer: nn.Linear, hook, @@ -174,9 +132,7 @@ def _run_hook_and_evaluate( # Run forward passes all_activations = [] for _ in range(num_iterations): - activations = torch.randn( - 16, 8, layer.in_features - ) # seq=16, batch=8, in_features=50 (Megatron format) + activations = torch.randn(16, 8, layer.in_features) # seq=16, batch=8, in_features=50 all_activations.append(activations) _ = layer(activations) From db866d90638ef95e53822ebd2a354cb565eff397 Mon Sep 17 00:00:00 2001 From: Liana Mikaelyan <45925959+LianaMikael@users.noreply.github.com> Date: Mon, 22 Dec 2025 12:07:01 +0000 Subject: [PATCH 27/62] Clean up Puzzle Compress Tutorial (#711) ## What does this PR do? This PR updates the Puzzle compress tutorial: - remove verbose logs - add more information on how to choose config parameters - fix unused import --------- Signed-off-by: Liana Mikaelyan Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/compress/README.md | 63 +++++++++++++++---- .../Llama-3_1-8B.yaml | 5 +- .../llama-3_1-8B_pruneffn_memory.yaml | 2 +- .../validate_model_defaults.yaml | 2 +- .../decilm/deci_lm_hf_code/modeling_decilm.py | 3 + .../tools/bypassed_training/child_init.py | 2 - ...validate_puzzle_with_multi_replacements.py | 4 +- .../nas/plugins/megatron_hooks/__init__.py | 6 +- 8 files changed, 65 insertions(+), 22 deletions(-) diff --git a/examples/compress/README.md b/examples/compress/README.md index 755b6090e8..42e55892e5 100644 --- a/examples/compress/README.md +++ b/examples/compress/README.md @@ -9,7 +9,7 @@ The supported modifications are: To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on Mixed-Integer Programming (MIP) algorithm to find the most optimal combination of layer modifications that satisfy the target requirements. -In this example, we compress the [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. +In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. ## Environment @@ -21,17 +21,15 @@ pip install -e .[hf,compress] - For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU. -## Compress the Model - -1. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file. +- To make use of [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), you need to accept the terms and conditions for the corresponding model and the dataset in the Huggingface Hub. Log in to the Huggingface Hub and enter your HF token. - **_NOTE:_** - How to choose `intermediate_size_list`? - The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g. 15%, 20%, 30% etc of the original). Note that the values must be hardware-friendly (divisible by a 256) to avoid issues with tensor operations in subsequent steps. +```bash +hf auth login +``` - Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` MiB. This means that the algorithm will choose the candidates with highest accuracy that also meet the specified requirements. +## Compress the Model -2. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). +1. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). dataset split: "code", "math", "stem", "chat", excluding reasoning samples (2.62GB) @@ -39,10 +37,24 @@ pip install -e .[hf,compress] python -m modelopt.torch._compress.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2 ``` +2. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file. + + - `puzzle_dir` indicates a new directory for saving the resulting model. + - `input_hf_model_path` indicates the local directory with the input model checkpoint. + - `dataset_path` indicates the directory with the dataset downloaded earlier. + + **_NOTE:_** + How to choose `intermediate_size_list`? + The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g. 15%, 20%, 30% etc of the original). Note that the values must be hardware-friendly (divisible by a 256) to avoid issues with tensor operations in subsequent steps. + + Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` MiB. This means that the algorithm will choose the candidates with highest accuracy that also meet the specified requirements. + + We can also set the target size of the resulting model using `num_params = 7_000_000_000`. This will be used as an upper bound for the number of parameters of the model. + 3. Run the compression script. ```bash - torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Compress Progress" + torchrun --nproc_per_node 2 examples/compress/main.py --config examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Compress Progress" ``` This will save the full output to `log.txt` and display the following progress on screen: @@ -110,7 +122,7 @@ pip install -e .[hf,compress] Average losses = {'lm_loss': 1.7577573340386152, 'token_accuracy_top_1': 0.6225490570068359, 'token_accuracy_top_5': 0.846257209777832, 'token_accuracy_top_10': 0.8987817764282227} ``` - 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). Let's rerun MIP search aiming for 15% memory reduction. + 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). ## Re-run MIP Search with different constraints @@ -194,6 +206,31 @@ lm_eval --model hf \ --batch_size 4 ``` -## Advanced usage +## Inference Performance Benchmarking + +Now let's evaluate how much speedup we get with the compressed model in terms of throughput and latency. + +- Install [vLLM from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source). +- Rearrange the model safetensors to be used for vLLM. + +```bash +cd path/to/model +mv subblocks_safetensors/* . +sed -i 's+subblocks_safetensors/++g' model.safetensors.index.json +``` + +- Benchmark latency + +```bash +vllm bench latency --model path/to/model --load-format safetensors --trust-remote-code +``` + +- Benchmark throughput + +```bash +vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --load-format safetensors --trust-remote-code +``` + +## Advanced Usage -Modify `path/to/Llama-3_1-8B yaml` file for advanced compression scenarios. +Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios. diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml index 133fe0b777..7045e0d002 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -9,7 +9,7 @@ defaults: puzzle_dir: ??? teacher_dir: ${puzzle_dir}/ckpts/teacher/ replacement_library_path: ${puzzle_dir}/replacement_library.json -dataset_path: ??? # path to v0.4_mini +dataset_path: ??? # ppath to Nemotron-Post-Training-Dataset-v2 skip_realize_model: false @@ -40,7 +40,7 @@ scoring: teacher_dir: ${to_path:${teacher_dir}} output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation - eval_samples: 10 # default is 128 + eval_samples: 128 micro_batch_size: 1 seed: 42 shuffle_seed: 444 @@ -77,6 +77,7 @@ mip: human_constraints: target_memory: 78_000 + num_params: 7_000_000_000 mip_constraints: metric_overrides: diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml index cfd7f93e81..c9a0cabf30 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml @@ -14,7 +14,7 @@ puzzle_dir: /workspace/puzzle_dir # MIP memory constraint (in MiB) mip: human_constraints: - target_memory: 96_000 # 96 GiB + target_memory: 78_000 # 78 GiB # FFN intermediate sizes to search over (heterogeneous architecture) pruning: diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml index 9e662c4e13..202af6eb02 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -14,4 +14,4 @@ write_results: false calc_losses_on_cpu: false activations_log_dir: model_name_or_path: -load_dataset_fn: ${get_object:utils.data.dataloaders.load_from_disk_fn} +load_dataset_fn: ${get_object:modelopt.torch._compress.utils.data.dataloaders.load_from_disk_fn} diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py index 808533d7f8..22d00ea773 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py @@ -1020,6 +1020,9 @@ def __init__(self, config: DeciLMConfig, layer_idx: int | tuple[int, ...]): self.ffn_config = self.block_config.ffn self.layer_idx = layer_idx + if not config._attn_implementation: + config._attn_implementation = "eager" + if not self.attention_config.no_op: self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.attention_config.replace_with_linear: diff --git a/modelopt/torch/_compress/tools/bypassed_training/child_init.py b/modelopt/torch/_compress/tools/bypassed_training/child_init.py index 3e2c42f09c..1bd36fa090 100644 --- a/modelopt/torch/_compress/tools/bypassed_training/child_init.py +++ b/modelopt/torch/_compress/tools/bypassed_training/child_init.py @@ -471,7 +471,6 @@ def create_child_state_dict( copy_start_time = time.time() keys_to_copy_from_orig_model = set(keys.values()) - ignored_keys for key in keys_to_copy_from_orig_model: - aprint(f"copying {key} from original_state_dict") # Memory optimization: avoid unnecessary copies tensor = original_state_dict[key] if not tensor.is_contiguous(): @@ -877,7 +876,6 @@ def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: if len(ACTIVATIONS_LOG) == 0: assert "activations_log_dir" in mlp_init_config activations_log_dir = mlp_init_config["activations_log_dir"] - print(f"Loading activations_log from {activations_log_dir}") ACTIVATIONS_LOG.update( { module_name: module_log diff --git a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py index ca02998684..51ee168d60 100644 --- a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py @@ -158,8 +158,8 @@ def validate_puzzle_solutions(args: DictConfig) -> None: list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions" ): layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution) - realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) - # realizable_as_symlinks = False + # realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) + realizable_as_symlinks = False model_config = replacement_library.create_model_config(layer_replacements) if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): model = replacement_library.load_model(layer_replacements) diff --git a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py index 0ba4405183..996d531392 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py @@ -14,6 +14,10 @@ # limitations under the License. """Forward hooks for estimating importance scores for pruning.""" +from modelopt.torch.utils import import_plugin + from .base_hooks import * from .base_hooks_analysis import * -from .megatron_hooks import * + +with import_plugin("megatron_hooks"): + from .megatron_hooks import * From 2e813bfb5671849f845ccd6d68475cce93c67c8d Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 22 Dec 2025 18:03:34 +0100 Subject: [PATCH 28/62] Two bug fixes: mix checkpointing and dtype (#718) ## What does this PR do? Two bug fixes: 1) Saving mip checkpoint (correctly saving model code files) 2) Passing dtype as object instead of string to calculate_losses_pipeline and load_and_shard_model Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/tools/validate_model.py | 11 +++-------- .../tools/validate_puzzle_with_multi_replacements.py | 12 +++--------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index d3d71a4198..456f9fab87 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -18,7 +18,7 @@ the loss, and optionally registers hooks to capture the inputs and the outputs of pytorch modules that are used for activation scoring for pruning. -TODO: Consider moving this a separate module dedicated for scoring. +TODO: Consider moving this a separate module dedicated for scoring """ import textwrap @@ -130,11 +130,6 @@ def validate_model( - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. Returns (None, None) if not on master rank. """ - # convert model_dtype and autocast_dtype from string to torch.dtype - if isinstance(args.model_dtype, str): - args.model_dtype = getattr(torch, args.model_dtype.strip("torch.")) - if isinstance(args.autocast_dtype, str): - args.autocast_dtype = getattr(torch, args.autocast_dtype.strip("torch.")) if val_dataloader is None: val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None @@ -199,7 +194,7 @@ def validate_model( calc_on_cpu=args.calc_losses_on_cpu, just_model_forward=just_model_forward, checkpoint_manager=checkpoint_manager, - autocast_dtype=args.autocast_dtype, + autocast_dtype=getattr(torch, args.autocast_dtype.strip("torch.")), ) if losses is not None: @@ -232,7 +227,7 @@ def prepare_model( model = load_and_shard_model( args.model_name_or_path, model_config_overrides={"block_size": args.block_size}, - model_dtype=args.model_dtype, + model_dtype=getattr(torch, args.model_dtype.strip("torch.")), ) else: try: diff --git a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py index 51ee168d60..6bc4d11b35 100644 --- a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py @@ -15,7 +15,7 @@ """Validates puzzle solutions by applying layer replacements and evaluating model performance. -TODO: Consider moving this a separate module dedicated for scoring. +TODO: Consider moving this a separate module dedicated for scoring """ # mypy: ignore-errors @@ -42,6 +42,7 @@ copy_tokenizer, ) from modelopt.torch._compress.tools.checkpoint_utils_hf import ( + copy_deci_lm_hf_code, save_checkpoint, save_safetensors_index, ) @@ -182,7 +183,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None: save_checkpoint(model, checkpoint_dir) copy_tokenizer(args.tokenizer_name, checkpoint_dir) - copy_hf_code(checkpoint_dir) + copy_deci_lm_hf_code(checkpoint_dir) dist.barrier() @@ -246,13 +247,6 @@ def save_checkpoint_as_symlinks( ) -def copy_hf_code(checkpoint_dir: Path) -> None: - code_dir = Path(__file__).parent / "deci_lm_hf_code" - print(f"copying hf code from {code_dir} ") - for file in code_dir.glob("*.py"): - shutil.copy(file, checkpoint_dir / file.name) - - def _load_tokenizer(args: DictConfig) -> PreTrainedTokenizerBase: tokenizer = None if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: From 0eecfc6e9fb3939b6cb73a36fefbef56a643f963 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 13 Jan 2026 23:50:35 +0530 Subject: [PATCH 29/62] Fix test assertions for 2-gpu (#772) - Assertions should account for layers split across PP ranks Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- tests/gpu/torch/_compress/test_compress.py | 32 ++++++++++--------- .../test_mcore_gpt_minitron_pruning.py | 16 +++++----- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/gpu/torch/_compress/test_compress.py b/tests/gpu/torch/_compress/test_compress.py index 24b8b8b2ec..dd6e0eb5a3 100644 --- a/tests/gpu/torch/_compress/test_compress.py +++ b/tests/gpu/torch/_compress/test_compress.py @@ -35,7 +35,7 @@ def test_compress(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( - size=torch.cuda.device_count(), + size=min(torch.cuda.device_count(), 2), # assertions configured for atmost 2 GPUs job=partial(_test_compress_multiprocess_job, project_root_path, tmp_path), backend="nccl", ) @@ -64,10 +64,9 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran # # Check assertions # + # assertions for the score_pruning_activations step 1 + _assert_score_pruning_activations(puzzle_dir) if rank == 0: - # assertions for the score_pruning_activations step 1 - _assert_score_pruning_activations(puzzle_dir) - # assertions for the pruning_ckpts step 2 assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() @@ -103,20 +102,23 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran def _assert_score_pruning_activations(puzzle_dir: Path): """Assertions for the score_pruning_activations step 1.""" rank = dist.rank() + size = dist.size() rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" assert (puzzle_dir / rank_filepath).is_file() pruning_scores = torch.load(puzzle_dir / rank_filepath) layer_names = list(pruning_scores.keys()) - assert len(layer_names) == 2 - - # Check specific values for layer 0 - layer_0 = pruning_scores[layer_names[0]] - assert layer_0["score"][0].item() == 371 - assert layer_0["channels_importance_ascending"][0].item() == 140 - - # Check specific values for layer 1 - layer_1 = pruning_scores[layer_names[1]] - assert layer_1["score"][0].item() == 269 - assert layer_1["channels_importance_ascending"][0].item() == 366 + assert len(layer_names) == 2 // size + + if size == 1 or rank == 0: + # Check specific values for layer 0 + layer_0 = pruning_scores[layer_names[0]] + assert layer_0["score"][0].item() == 371 + assert layer_0["channels_importance_ascending"][0].item() == 140 + + if size == 1 or rank == 1: + # Check specific values for layer 1 + layer_1 = pruning_scores[layer_names[1 if size == 1 else 0]] + assert layer_1["score"][0].item() == 269 + assert layer_1["channels_importance_ascending"][0].item() == 366 diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index e58473e8ac..46d48ea2b2 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -297,40 +297,40 @@ def forward_loop(m): # TODO: Simplify it: this unit test is too long, # hard to read (the same set of assertions across different test cases with if-else). - assert len(pruning_scores["activations_per_rank"]) == 1 - rank_0_activations = pruning_scores["activations_per_rank"][0] + assert len(pruning_scores["activations_per_rank"]) == size + activations = pruning_scores["activations_per_rank"][rank] # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4) - if pruned_ffn_div == 4: + if size == 1 and pruned_ffn_div == 4: # Layer scores _assert_approx(pruning_scores["layer_scores"], {1: 0.028923, 2: 0.046508}) # Validate decoder.layers.0.mlp activations - mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"] + mlp_0_acts = activations["decoder.layers.0.mlp"] _assert_approx(mlp_0_acts.min().item(), 0.000026) _assert_approx(mlp_0_acts.max().item(), 0.000729) _assert_approx(mlp_0_acts.mean().item(), 0.000201) # Validate decoder.layers.1.mlp activations - mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"] + mlp_1_acts = activations["decoder.layers.1.mlp"] _assert_approx(mlp_1_acts.min().item(), 0.000022) _assert_approx(mlp_1_acts.max().item(), 0.000762) _assert_approx(mlp_1_acts.mean().item(), 0.000162) # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2) - elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1: + elif size == 1 and pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1: # Layer scores _assert_approx(pruning_scores["layer_scores"], {1: 0.028056, 2: 0.038353}) # Validate decoder.layers.0.self_attention activations - attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"] + attn_0_acts = activations["decoder.layers.0.self_attention"] assert attn_0_acts.shape == torch.Size([hidden_size]) _assert_approx(attn_0_acts.min().item(), 0.010091) _assert_approx(attn_0_acts.max().item(), 0.023826) _assert_approx(attn_0_acts.mean().item(), 0.014548) # Validate decoder.layers.1.self_attention activations - attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"] + attn_1_acts = activations["decoder.layers.1.self_attention"] assert attn_1_acts.shape == torch.Size([hidden_size]) _assert_approx(attn_1_acts.min().item(), 0.009982) _assert_approx(attn_1_acts.max().item(), 0.035644) From 43b3cfa0205b2da2f590ec26048473f9c9120168 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 14 Jan 2026 18:07:59 +0530 Subject: [PATCH 30/62] Rename compress to puzzletron (#776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? - As per slack discussion, rename `modelopt.torch._compress` to `modelopt.torch.puzzletron` - Auto-fix some formatting (previously skipped because folder was private (`_compress`) - Fix some doc building errors ## Summary by CodeRabbit * **New Features** * Introduced Puzzletron as the primary optimization algorithm with updated module structure and comprehensive documentation. * **Bug Fixes & Improvements** * Modernized type annotations and improved code quality. * **Documentation** * Updated examples and tutorials to reflect Puzzletron functionality. * Added new Puzzletron configuration examples and step-by-step guides. * **Chores** * Updated dependencies in setup.py to support Puzzletron. * Reorganized test utilities and resources. ✏️ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/CODEOWNERS | 2 +- .pre-commit-config.yaml | 12 +- examples/pruning/README.md | 3 +- examples/{compress => puzzletron}/README.md | 30 ++--- .../Llama-3_1-8B.yaml | 0 .../llama-3_1-8B_pruneffn_memory.yaml | 6 +- .../pruning/attn_pruning.yaml | 0 .../pruning/ffn_pruning.yaml | 0 .../pruning/hidden_dim_pruning.yaml | 0 .../pruning/pruning_defaults.yaml | 0 .../validate_model_defaults.yaml | 2 +- .../validate_solutions_defaults.yaml | 0 examples/{compress => puzzletron}/main.py | 36 +++--- .../nas/plugins/megatron_hooks/base_hooks.py | 4 +- .../torch/{_compress => puzzletron}/README.md | 0 .../{_compress => puzzletron}/__init__.py | 0 .../activation_hooks/__init__.py | 0 .../activation_hooks/utils.py | 5 +- .../score_pruning_activations.py | 13 +- .../build_library_and_stats.py | 17 +-- .../dataset/__init__.py | 0 .../dataset/prepare_dataset.py | 2 +- .../decilm/conversion_utils.py | 0 .../converters/convert_llama3_to_decilm.py | 11 +- .../decilm/deci_lm_hf_code/__init__.py | 0 .../decilm/deci_lm_hf_code/block_config.py | 0 .../deci_lm_hf_code/configuration_decilm.py | 0 .../megatron_lm__mamba_mixer.py | 0 .../megatron_lm__megatron_tokenizer.py | 0 .../deci_lm_hf_code/megatron_lm__tokenizer.py | 0 .../decilm/deci_lm_hf_code/modeling_decilm.py | 0 .../deci_lm_hf_code/tokenization_decilm.py | 0 .../transformers_4_44_2__activations.py | 0 .../transformers_4_44_2__cache_utils.py | 0 ...ransformers_4_44_2__configuration_llama.py | 0 ...ormers_4_44_2__modeling_attn_mask_utils.py | 0 ...g_flash_attention_utils_backward_compat.py | 0 .../transformers_4_44_2__modeling_outputs.py | 0 ...ransformers_4_44_2__modeling_rope_utils.py | 0 .../transformers_4_44_2__pytorch_utils.py | 0 .../transformers_4_51_3__cache_utils.py | 0 ...ansformers_4_51_3__configuration_llama4.py | 0 ...rmers_4_51_3__modeling_llama4_attention.py | 0 .../decilm/deci_lm_hf_code/variable_cache.py | 0 .../decilm/deci_lm_hf_code/vllm_yarn_utils.py | 0 .../mip/mip_and_realize_models.py | 9 +- .../mip/mip_with_multi_layer_replacements.py | 16 +-- .../mip/run_puzzle.py | 29 +++-- .../{_compress => puzzletron}/mip/utils.py | 0 .../nas/plugins/puzzletron_nas_plugin.py} | 76 ++++++----- .../pruning/pruning_ckpts.py | 38 +++--- .../compress.py => puzzletron/puzzletron.py} | 25 ++-- .../build_replacement_library.py | 47 +++---- .../replacement_library.py | 20 ++- .../replacement_library/replacement_utils.py | 11 +- .../scoring/scoring.py | 8 +- .../sewing_kit/__init__.py | 0 .../sewing_kit/core.py | 2 + .../sewing_kit/passage/__init__.py | 0 .../sewing_kit/passage/core.py | 2 + .../sewing_kit/utils.py | 0 .../calc_subblock_params_and_memory.py | 17 ++- .../subblock_stats/calc_subblock_stats.py | 42 +++---- .../tools/__init__.py | 0 .../tools/bypassed_training/child_init.py | 102 +++++++-------- .../init_child_from_parent.py | 31 +++-- .../tools/checkpoint_utils.py | 16 +-- .../tools/checkpoint_utils_hf.py | 31 ++--- .../{_compress => puzzletron}/tools/common.py | 0 .../tools/hydra_utils.py | 0 .../tools/kd_model.py | 0 .../{_compress => puzzletron}/tools/logger.py | 16 +-- .../tools/post_init_sparse.py | 10 +- .../tools/robust_json.py | 0 .../tools/sharded_checkpoint_utils.py | 25 ++-- .../tools/validate_model.py | 99 ++++++++------- ...validate_puzzle_with_multi_replacements.py | 119 +++++++++--------- .../tools/validation_utils.py | 16 +-- .../utils/checkpoint_manager.py | 27 ++-- .../utils/data/dataloaders.py | 8 +- .../utils/data/dataset.py | 20 ++- .../utils/parsing.py | 0 .../{_compress => puzzletron}/utils/utils.py | 17 +-- .../utils/validate_runtime_pipeline.py | 20 ++- .../utils/validation.py | 24 ++-- pyproject.toml | 2 +- setup.py | 4 +- .../configs/Llama-3_1-8B-attn-pruning.yaml | 0 .../configs/Llama-3_1-8B-ffn-pruning.yaml | 0 .../configs/pruning/attn_pruning.yaml | 0 .../configs/pruning/ffn_pruning.yaml | 0 .../configs/pruning/hidden_dim_pruning.yaml | 0 .../configs/pruning/pruning_defaults.yaml | 0 .../configs/validate_model_defaults.yaml | 2 +- .../configs/validate_solutions_defaults.yaml | 0 .../tokenizer/special_tokens_map.json | 0 .../resources/tokenizer/tokenizer.json | 0 .../resources/tokenizer/tokenizer_config.json | 0 .../resources/tokenizer/truncate_tokenizer.py | 0 .../torch/puzzletron/utils.py} | 6 +- .../{_compress => puzzletron}/conftest.py | 0 ..._convert_llama3_config_to_decilm_config.py | 7 +- .../nas/plugins/test_nas_convert.py | 20 ++- .../nas/plugins/test_nas_search.py | 12 +- .../test_puzzletron.py} | 25 ++-- 105 files changed, 518 insertions(+), 626 deletions(-) rename examples/{compress => puzzletron}/README.md (86%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml (100%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml (72%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml (100%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml (100%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml (100%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml (100%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml (80%) rename examples/{compress => puzzletron}/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml (100%) rename examples/{compress => puzzletron}/main.py (78%) rename modelopt/torch/{_compress => puzzletron}/README.md (100%) rename modelopt/torch/{_compress => puzzletron}/__init__.py (100%) rename modelopt/torch/{_compress => puzzletron}/activation_scoring/activation_hooks/__init__.py (100%) rename modelopt/torch/{_compress => puzzletron}/activation_scoring/activation_hooks/utils.py (97%) rename modelopt/torch/{_compress => puzzletron}/activation_scoring/score_pruning_activations.py (92%) rename modelopt/torch/{_compress => puzzletron}/build_library_and_stats.py (78%) rename modelopt/torch/{_compress => puzzletron}/dataset/__init__.py (100%) rename modelopt/torch/{_compress => puzzletron}/dataset/prepare_dataset.py (97%) rename modelopt/torch/{_compress => puzzletron}/decilm/conversion_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/converters/convert_llama3_to_decilm.py (93%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/__init__.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/block_config.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/configuration_decilm.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/modeling_decilm.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/tokenization_decilm.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/variable_cache.py (100%) rename modelopt/torch/{_compress => puzzletron}/decilm/deci_lm_hf_code/vllm_yarn_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/mip/mip_and_realize_models.py (89%) rename modelopt/torch/{_compress => puzzletron}/mip/mip_with_multi_layer_replacements.py (95%) rename modelopt/torch/{_compress => puzzletron}/mip/run_puzzle.py (96%) rename modelopt/torch/{_compress => puzzletron}/mip/utils.py (100%) rename modelopt/torch/{_compress/nas/plugins/compress_nas_plugin.py => puzzletron/nas/plugins/puzzletron_nas_plugin.py} (71%) rename modelopt/torch/{_compress => puzzletron}/pruning/pruning_ckpts.py (91%) rename modelopt/torch/{_compress/compress.py => puzzletron/puzzletron.py} (74%) rename modelopt/torch/{_compress => puzzletron}/replacement_library/build_replacement_library.py (93%) rename modelopt/torch/{_compress => puzzletron}/replacement_library/replacement_library.py (96%) rename modelopt/torch/{_compress => puzzletron}/replacement_library/replacement_utils.py (91%) rename modelopt/torch/{_compress => puzzletron}/scoring/scoring.py (91%) rename modelopt/torch/{_compress => puzzletron}/sewing_kit/__init__.py (100%) rename modelopt/torch/{_compress => puzzletron}/sewing_kit/core.py (99%) rename modelopt/torch/{_compress => puzzletron}/sewing_kit/passage/__init__.py (100%) rename modelopt/torch/{_compress => puzzletron}/sewing_kit/passage/core.py (99%) rename modelopt/torch/{_compress => puzzletron}/sewing_kit/utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/subblock_stats/calc_subblock_params_and_memory.py (95%) rename modelopt/torch/{_compress => puzzletron}/subblock_stats/calc_subblock_stats.py (93%) rename modelopt/torch/{_compress => puzzletron}/tools/__init__.py (100%) rename modelopt/torch/{_compress => puzzletron}/tools/bypassed_training/child_init.py (95%) rename modelopt/torch/{_compress => puzzletron}/tools/bypassed_training/init_child_from_parent.py (87%) rename modelopt/torch/{_compress => puzzletron}/tools/checkpoint_utils.py (92%) rename modelopt/torch/{_compress => puzzletron}/tools/checkpoint_utils_hf.py (94%) rename modelopt/torch/{_compress => puzzletron}/tools/common.py (100%) rename modelopt/torch/{_compress => puzzletron}/tools/hydra_utils.py (100%) rename modelopt/torch/{_compress => puzzletron}/tools/kd_model.py (100%) rename modelopt/torch/{_compress => puzzletron}/tools/logger.py (92%) rename modelopt/torch/{_compress => puzzletron}/tools/post_init_sparse.py (94%) rename modelopt/torch/{_compress => puzzletron}/tools/robust_json.py (100%) rename modelopt/torch/{_compress => puzzletron}/tools/sharded_checkpoint_utils.py (94%) rename modelopt/torch/{_compress => puzzletron}/tools/validate_model.py (73%) rename modelopt/torch/{_compress => puzzletron}/tools/validate_puzzle_with_multi_replacements.py (70%) rename modelopt/torch/{_compress => puzzletron}/tools/validation_utils.py (88%) rename modelopt/torch/{_compress => puzzletron}/utils/checkpoint_manager.py (93%) rename modelopt/torch/{_compress => puzzletron}/utils/data/dataloaders.py (97%) rename modelopt/torch/{_compress => puzzletron}/utils/data/dataset.py (96%) rename modelopt/torch/{_compress => puzzletron}/utils/parsing.py (100%) rename modelopt/torch/{_compress => puzzletron}/utils/utils.py (97%) rename modelopt/torch/{_compress => puzzletron}/utils/validate_runtime_pipeline.py (92%) rename modelopt/torch/{_compress => puzzletron}/utils/validation.py (97%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/Llama-3_1-8B-attn-pruning.yaml (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/Llama-3_1-8B-ffn-pruning.yaml (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/pruning/attn_pruning.yaml (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/pruning/ffn_pruning.yaml (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/pruning/hidden_dim_pruning.yaml (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/pruning/pruning_defaults.yaml (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/validate_model_defaults.yaml (76%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/configs/validate_solutions_defaults.yaml (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/tokenizer/special_tokens_map.json (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/tokenizer/tokenizer.json (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/tokenizer/tokenizer_config.json (100%) rename tests/{gpu/torch/_compress => _test_utils/torch/puzzletron}/resources/tokenizer/truncate_tokenizer.py (100%) rename tests/{gpu/torch/_compress/compress_test_utils.py => _test_utils/torch/puzzletron/utils.py} (96%) rename tests/gpu/torch/{_compress => puzzletron}/conftest.py (100%) rename tests/gpu/torch/{_compress => puzzletron}/decilm/converters/test_convert_llama3_config_to_decilm_config.py (90%) rename tests/gpu/torch/{_compress => puzzletron}/nas/plugins/test_nas_convert.py (86%) rename tests/gpu/torch/{_compress => puzzletron}/nas/plugins/test_nas_search.py (89%) rename tests/gpu/torch/{_compress/test_compress.py => puzzletron/test_puzzletron.py} (83%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 12118e20b3..15376996b6 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -17,7 +17,6 @@ modelopt/deploy @NVIDIA/modelopt-deploy-codeowners modelopt/onnx @NVIDIA/modelopt-onnx-codeowners modelopt/onnx/autocast @NVIDIA/modelopt-onnx-autocast-codeowners modelopt/torch @NVIDIA/modelopt-torch-codeowners -modelopt/torch/_compress @NVIDIA/modelopt-torch-compress-codeowners modelopt/torch/_deploy @NVIDIA/modelopt-torch-deploy-codeowners modelopt/torch/distill @NVIDIA/modelopt-torch-distill-codeowners modelopt/torch/export @NVIDIA/modelopt-torch-export-codeowners @@ -25,6 +24,7 @@ modelopt/torch/nas @NVIDIA/modelopt-torch-nas-prune-codeowners modelopt/torch/opt @NVIDIA/modelopt-torch-opt-codeowners modelopt/torch/peft @NVIDIA/modelopt-torch-peft-codeowners modelopt/torch/prune @NVIDIA/modelopt-torch-nas-prune-codeowners +modelopt/torch/puzzletron @NVIDIA/modelopt-torch-puzzletron-codeowners modelopt/torch/quantization @NVIDIA/modelopt-torch-quantization-codeowners modelopt/torch/sparsity @NVIDIA/modelopt-torch-sparsity-codeowners modelopt/torch/speculative @NVIDIA/modelopt-torch-speculative-codeowners diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 70bae3609a..c1895d9433 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,17 +24,17 @@ repos: hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - # See: commit hooks modifies block_config.py leading to test_compress.py failing (#25) · Issues · omniml / modelopt · GitLab + # See: commit hooks modifies block_config.py leading to test_puzzletron.py failing (#25) · Issues · omniml / modelopt · GitLab exclude: > (?x)^( - modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| - modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py )$ - id: ruff-format exclude: > (?x)^( - modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| - modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py )$ - repo: https://github.com/pre-commit/mirrors-mypy @@ -107,7 +107,7 @@ repos: examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| examples/speculative_decoding/server_generate.py| - modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py| + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py| )$ # Default hook for Apache 2.0 in c/c++/cuda files diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 2c2a7c785d..9792f2932c 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -7,6 +7,7 @@ Pruning can involve removal (prune) of Linear and Conv layers; and Transformer a This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model: 1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in NVIDIA Megatron-LM or NeMo framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model. +1. [Puzzletron](../puzzletron/README.md): An advanced pruning method by NVIDIA using Mixed Integer Programming (MIP) based NAS search algorithm. 1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints. 1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints. @@ -23,8 +24,6 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar -For more advanced pruning strategies, such as the [Puzzle methodology](https://arxiv.org/pdf/2411.19146), please see [Puzzle pruning example](../compress/README.md). - ## Pre-Requisites For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.09`) which has all the dependencies installed. diff --git a/examples/compress/README.md b/examples/puzzletron/README.md similarity index 86% rename from examples/compress/README.md rename to examples/puzzletron/README.md index 42e55892e5..e3a909d224 100644 --- a/examples/compress/README.md +++ b/examples/puzzletron/README.md @@ -1,6 +1,6 @@ -# Compress Algorithm Tutorial +# Puzzletron Algorithm Tutorial -This tutorial demonstrates how to compress large language models using the compress algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). +This tutorial demonstrates how to compress large language models using the puzzletron algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). The goal of the algorithm it to find the most optimal modifications to MLP and attention layers of the model, resulting in a heterogeneous model architecture. The supported modifications are: @@ -16,7 +16,7 @@ In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/ - Install Model-Optimizer in editable mode with the corresponding dependencies: ```bash -pip install -e .[hf,compress] +pip install -e .[hf,puzzletron] ``` - For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU. @@ -34,7 +34,7 @@ hf auth login dataset split: "code", "math", "stem", "chat", excluding reasoning samples (2.62GB) ```bash - python -m modelopt.torch._compress.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2 + python -m modelopt.torch.puzzletron.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2 ``` 2. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file. @@ -51,23 +51,23 @@ hf auth login We can also set the target size of the resulting model using `num_params = 7_000_000_000`. This will be used as an upper bound for the number of parameters of the model. -3. Run the compression script. +3. Run the puzzletron pipeline. ```bash - torchrun --nproc_per_node 2 examples/compress/main.py --config examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Compress Progress" + torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Puzzletron Progress" ``` This will save the full output to `log.txt` and display the following progress on screen: ```bash - [2025-11-02 12:06:34][rank-0][main.py:71] Compress Progress 1/8: starting compression pipeline - [2025-11-02 12:06:45][rank-0][compress_nas_plugin.py:123] Compress Progress 2/8: converting model from HF to DeciLM (single-gpu) - [2025-11-02 12:07:07][rank-0][compress_nas_plugin.py:132] Compress Progress 3/8: scoring pruning activations (multi-gpu) - [2025-11-02 12:11:36][rank-0][compress_nas_plugin.py:137] Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu) - [2025-11-02 12:12:20][rank-0][compress_nas_plugin.py:217] Compress Progress 5/8: building replacement library and subblock statistics (single-gpu) - [2025-11-02 12:12:21][rank-0][compress_nas_plugin.py:222] Compress Progress 6/8: calculating one block scores (multi-gpu) - [2025-11-02 12:50:41][rank-0][compress_nas_plugin.py:226] Compress Progress 7/8: running MIP and realizing models (multi-gpu) - [2025-11-02 12:52:34][rank-0][main.py:115] Compress Progress 8/8: compression pipeline completed (multi-gpu) + [2025-11-02 12:06:34][rank-0][main.py:71] Puzzletron Progress 1/8: starting puzzletron pipeline + [2025-11-02 12:06:45][rank-0][puzzletron_nas_plugin.py:123] Puzzletron Progress 2/8: converting model from HF to DeciLM (single-gpu) + [2025-11-02 12:07:07][rank-0][puzzletron_nas_plugin.py:132] Puzzletron Progress 3/8: scoring pruning activations (multi-gpu) + [2025-11-02 12:11:36][rank-0][puzzletron_nas_plugin.py:137] Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu) + [2025-11-02 12:12:20][rank-0][puzzletron_nas_plugin.py:217] Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu) + [2025-11-02 12:12:21][rank-0][puzzletron_nas_plugin.py:222] Puzzletron Progress 6/8: calculating one block scores (multi-gpu) + [2025-11-02 12:50:41][rank-0][puzzletron_nas_plugin.py:226] Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu) + [2025-11-02 12:52:34][rank-0][main.py:115] Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu) ``` Once the process is complete, the resulting network architecture will be recorded in `log.txt` for your review: @@ -132,7 +132,7 @@ This assumes pruning, replacement library building, NAS scoring, and subblock st For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`. ```bash -torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Compress Progress" +torchrun --nproc_per_node 2 examples/puzzletron/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" ``` This will generate the following network architecture (see `log.txt`): diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml similarity index 100% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml similarity index 72% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml index c9a0cabf30..20eec970e7 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml @@ -8,14 +8,14 @@ input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct # Dataset path for pruning and NAS scoring dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 -# Working directory for compression outputs +# Working directory for puzzletron outputs puzzle_dir: /workspace/puzzle_dir -# MIP memory constraint (in MiB) +# MIP memory constraint (in MiB) mip: human_constraints: target_memory: 78_000 # 78 GiB # FFN intermediate sizes to search over (heterogeneous architecture) pruning: - intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 + intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml similarity index 100% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml similarity index 100% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml similarity index 100% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml similarity index 100% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml similarity index 80% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml index 202af6eb02..ce1749d969 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -14,4 +14,4 @@ write_results: false calc_losses_on_cpu: false activations_log_dir: model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch._compress.utils.data.dataloaders.load_from_disk_fn} +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml similarity index 100% rename from examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml rename to examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml diff --git a/examples/compress/main.py b/examples/puzzletron/main.py similarity index 78% rename from examples/compress/main.py rename to examples/puzzletron/main.py index 2c3343c374..16d4de385e 100644 --- a/examples/compress/main.py +++ b/examples/puzzletron/main.py @@ -14,14 +14,14 @@ # limitations under the License. """ -Main script for running the compress algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146). +Main script for running the puzzletron algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146). This script provides two modes: -1. Default mode: Runs the full compression pipeline +1. Default mode: Runs the full puzzletron pipeline 2. MIP-only mode: Runs only the MIP search and realize models phase Usage: - # Full compression pipeline + # Full puzzletron pipeline torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml # Only MIP search and realize models phase @@ -32,21 +32,21 @@ from datetime import timedelta from pathlib import Path -import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.nas as mtn +import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.tools.hydra_utils import ( +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +from modelopt.torch.puzzletron.tools.hydra_utils import ( initialize_hydra_config_for_dir, register_hydra_resolvers, ) -from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch.puzzletron.tools.logger import mprint def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser( - description="Compress large language models using the Compress algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)" + description="Compress large language models using the Puzzletron algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)" ) parser.add_argument( "--config", @@ -63,13 +63,13 @@ def parse_args(): return parser.parse_args() -def run_full_compress(hydra_config_path: str): - """Run the full compression pipeline. +def run_full_puzzletron(hydra_config_path: str): + """Run the full puzzletron pipeline. Args: config_path: Path to the YAML configuration file """ - mprint("Compress Progress 1/8: starting compression pipeline") + mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") dist.setup(timeout=timedelta(10)) # Register Hydra custom resolvers (needed for config resolution) @@ -88,12 +88,12 @@ def run_full_compress(hydra_config_path: str): # Convert model (convert from HF to DeciLM, score pruning activations, # prune the model and save pruned checkpoints) - input_model = CompressModel() + input_model = PuzzletronModel() converted_model = mtn.convert( input_model, mode=[ ( - "compress", + "puzzletron", { "puzzle_dir": str(hydra_cfg.puzzle_dir), "input_model_path": hydra_cfg.input_hf_model_path, @@ -115,7 +115,7 @@ def run_full_compress(hydra_config_path: str): ) dist.cleanup() - mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") + mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") def run_mip_only(hydra_config_path: str): @@ -144,12 +144,12 @@ def run_mip_only(hydra_config_path: str): ) # mip_and_realize_models (distributed processing) - # TODO: How to make it part of mnt.search() api, similarly to run_full_compress() API - mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") + # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API + mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) dist.cleanup() - mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") + mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") def main(): @@ -158,7 +158,7 @@ def main(): if args.mip_only: run_mip_only(hydra_config_path=args.config) else: - run_full_compress(hydra_config_path=args.config) + run_full_puzzletron(hydra_config_path=args.config) if __name__ == "__main__": diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index bfc9b9290b..56436acfdd 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -26,8 +26,8 @@ from torch import nn import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.tools.logger import aprint -from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch.puzzletron.tools.logger import aprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump __all__ = [ "ForwardHook", diff --git a/modelopt/torch/_compress/README.md b/modelopt/torch/puzzletron/README.md similarity index 100% rename from modelopt/torch/_compress/README.md rename to modelopt/torch/puzzletron/README.md diff --git a/modelopt/torch/_compress/__init__.py b/modelopt/torch/puzzletron/__init__.py similarity index 100% rename from modelopt/torch/_compress/__init__.py rename to modelopt/torch/puzzletron/__init__.py diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py similarity index 100% rename from modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py rename to modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py similarity index 97% rename from modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py rename to modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py index 931ac762f5..ab7eed2ac3 100644 --- a/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -15,11 +15,11 @@ # mypy: ignore-errors """Provides a function to register activation hooks for a model. -Activation hooks are used to compute activation scores for pruning.""" +Activation hooks are used to compute activation scores for pruning. +""" import re -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( ForwardHook, IndependentChannelContributionHook, @@ -27,6 +27,7 @@ IterativeChannelContributionHook, LayerNormContributionHook, ) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM def register_activation_hooks( diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py similarity index 92% rename from modelopt/torch/_compress/activation_scoring/score_pruning_activations.py rename to modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py index f271a5f4f9..ef5e5e9ad2 100644 --- a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py @@ -19,13 +19,12 @@ from omegaconf import DictConfig import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.validate_model import validate_model +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.validate_model import validate_model def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: - """ - Determine if the activation hook method has proper checkpoint support implemented. + """Determine if the activation hook method has proper checkpoint support implemented. Args: activation_hooks_kwargs: Hook configuration @@ -47,8 +46,7 @@ def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: def check_scoring_completion(activations_log_dir: str, activation_hooks_kwargs=None) -> bool: - """ - Check if scoring is already completed by looking for the expected output files. + """Check if scoring is already completed by looking for the expected output files. Also checks if the scoring method is safe for resume. Args: @@ -89,8 +87,7 @@ def check_scoring_completion(activations_log_dir: str, activation_hooks_kwargs=N def should_skip_scoring_completely(cfg: DictConfig) -> bool: - """ - Determine if we should skip scoring entirely (only if 100% complete). + """Determine if we should skip scoring entirely (only if 100% complete). Partial progress should proceed to validate_model for proper resume. Args: diff --git a/modelopt/torch/_compress/build_library_and_stats.py b/modelopt/torch/puzzletron/build_library_and_stats.py similarity index 78% rename from modelopt/torch/_compress/build_library_and_stats.py rename to modelopt/torch/puzzletron/build_library_and_stats.py index 28e0f386c2..5f04f60494 100644 --- a/modelopt/torch/_compress/build_library_and_stats.py +++ b/modelopt/torch/puzzletron/build_library_and_stats.py @@ -14,8 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Unified command that runs build_replacement_library followed by calc_subblock_stats. +"""Unified command that runs build_replacement_library followed by calc_subblock_stats. This script combines the functionality of both commands into a single workflow: 1. First, it builds the replacement library for the puzzle @@ -23,27 +22,23 @@ Usage: - python modelopt.torch._compress.build_library_and_stats.py --config-dir configs --config-name Llama-3_1-8B puzzle_dir=/path/to/puzzle/dir dataset_path=/path/to/dataset + python modelopt.torch.puzzletron.build_library_and_stats.py --config-dir configs --config-name Llama-3_1-8B puzzle_dir=/path/to/puzzle/dir dataset_path=/path/to/dataset The script uses the same Hydra configuration as the individual commands and supports all the same configuration parameters for both build_replacement_library and calc_subblock_stats. """ -import hydra from omegaconf import DictConfig -from modelopt.torch._compress.replacement_library.build_replacement_library import ( +from modelopt.torch.puzzletron.replacement_library.build_replacement_library import ( launch_build_replacement_library, ) -from modelopt.torch._compress.subblock_stats.calc_subblock_stats import launch_calc_subblock_stats -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.utils.parsing import format_global_config +from modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats import launch_calc_subblock_stats +from modelopt.torch.puzzletron.tools.logger import mprint def launch_build_library_and_stats(cfg: DictConfig) -> None: - """ - Launch both build_replacement_library and calc_subblock_stats in sequence. + """Launch both build_replacement_library and calc_subblock_stats in sequence. Args: cfg: Hydra configuration containing settings for both commands diff --git a/modelopt/torch/_compress/dataset/__init__.py b/modelopt/torch/puzzletron/dataset/__init__.py similarity index 100% rename from modelopt/torch/_compress/dataset/__init__.py rename to modelopt/torch/puzzletron/dataset/__init__.py diff --git a/modelopt/torch/_compress/dataset/prepare_dataset.py b/modelopt/torch/puzzletron/dataset/prepare_dataset.py similarity index 97% rename from modelopt/torch/_compress/dataset/prepare_dataset.py rename to modelopt/torch/puzzletron/dataset/prepare_dataset.py index 072640777a..6f1749697c 100644 --- a/modelopt/torch/_compress/dataset/prepare_dataset.py +++ b/modelopt/torch/puzzletron/dataset/prepare_dataset.py @@ -19,7 +19,7 @@ import fire import numpy as np -from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch.puzzletron.tools.logger import mprint def process_and_save_dataset( diff --git a/modelopt/torch/_compress/decilm/conversion_utils.py b/modelopt/torch/puzzletron/decilm/conversion_utils.py similarity index 100% rename from modelopt/torch/_compress/decilm/conversion_utils.py rename to modelopt/torch/puzzletron/decilm/conversion_utils.py diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/puzzletron/decilm/converters/convert_llama3_to_decilm.py similarity index 93% rename from modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py rename to modelopt/torch/puzzletron/decilm/converters/convert_llama3_to_decilm.py index 4df9f009a6..c5f107ea1e 100644 --- a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py +++ b/modelopt/torch/puzzletron/decilm/converters/convert_llama3_to_decilm.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Convert a Llama3 model to a DeciLM model.""" +"""Convert a Llama3 model to a DeciLM model.""" #!/usr/bin/env python3 from pathlib import Path @@ -23,10 +22,10 @@ from fire import Fire from transformers import LlamaConfig -from modelopt.torch._compress.decilm.conversion_utils import convert_model_weights_to_decilm -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.tools.checkpoint_utils import copy_tokenizer -from modelopt.torch._compress.tools.checkpoint_utils_hf import copy_deci_lm_hf_code +from modelopt.torch.puzzletron.decilm.conversion_utils import convert_model_weights_to_decilm +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import copy_deci_lm_hf_code """ example: diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/__init__.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/__init__.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/tokenization_decilm.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/tokenization_decilm.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py similarity index 100% rename from modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py rename to modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py diff --git a/modelopt/torch/_compress/mip/mip_and_realize_models.py b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py similarity index 89% rename from modelopt/torch/_compress/mip/mip_and_realize_models.py rename to modelopt/torch/puzzletron/mip/mip_and_realize_models.py index a3a1a84b91..e241021ec9 100644 --- a/modelopt/torch/_compress/mip/mip_and_realize_models.py +++ b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py @@ -17,20 +17,19 @@ # mypy: ignore-errors from pathlib import Path -from typing import List import torch from omegaconf import DictConfig import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.mip.run_puzzle import run_puzzle -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( +from modelopt.torch.puzzletron.mip.run_puzzle import run_puzzle +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.validate_puzzle_with_multi_replacements import ( validate_puzzle_solutions, ) -def launch_mip(cfg: DictConfig) -> List[str]: +def launch_mip(cfg: DictConfig) -> list[str]: solution_paths = run_puzzle(args=cfg.mip) return solution_paths diff --git a/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py b/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py similarity index 95% rename from modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py rename to modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py index 438db3312e..5b4eccbc15 100644 --- a/modelopt/torch/_compress/mip/mip_with_multi_layer_replacements.py +++ b/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py @@ -19,14 +19,14 @@ import math import warnings from collections import defaultdict +from collections.abc import Hashable, Iterable from copy import deepcopy from random import random -from typing import Any, Hashable, Iterable, Optional, TypeAlias +from typing import Any, TypeAlias from mip import BINARY, Model, maximize, minimize, xsum -from modelopt.torch._compress.mip.utils import ( - InfeasibleError, +from modelopt.torch.puzzletron.mip.utils import ( consecutive_ngrams, get_nested_key, sort_replacements, @@ -42,7 +42,7 @@ def run_mip( objective: str, constraints: dict[str, float], bigger_is_better: bool, - max_seconds_per_solution: Optional[float] = None, + max_seconds_per_solution: float | None = None, ) -> tuple[ChosenReplacements, float, dict[str, float]]: orig_num_replacements = len(replacements) replacements = { @@ -60,7 +60,7 @@ def run_mip( mip_model = Model() objective_vars = [] - constraint_vars = {constraint_key: [] for constraint_key in constraints.keys()} + constraint_vars = {constraint_key: [] for constraint_key in constraints} choice_indicators_by_layer = defaultdict(list) for replacement_id, replacement in replacements.items(): is_chosen = mip_model.add_var(var_type=BINARY) @@ -71,7 +71,7 @@ def run_mip( objective_vars.append(is_chosen * get_nested_key(replacement, objective)) - for constraint_key in constraints.keys(): + for constraint_key in constraints: constraint_vars[constraint_key].append( is_chosen * get_nested_key(replacement, constraint_key) ) @@ -107,7 +107,7 @@ def run_mip( # Trust But Verify: calculate total value and costs, and check that all the constraints are filled total_value = 0.0 - total_costs = {constraint_key: 0 for constraint_key in constraints.keys()} + total_costs = dict.fromkeys(constraints.keys(), 0) chosen_replacements: ChosenReplacements = [] chosen_layers = [] for replacement_id, replacement in replacements.items(): @@ -116,7 +116,7 @@ def run_mip( assert replacement not in chosen_replacements chosen_replacements.append(replacement) total_value += get_nested_key(replacement, objective) - for constraint_key in constraints.keys(): + for constraint_key in constraints: total_costs[constraint_key] += get_nested_key(replacement, constraint_key) for parent_layer_idx in replacement["parent_layer_indices"]: assert parent_layer_idx not in chosen_layers diff --git a/modelopt/torch/_compress/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py similarity index 96% rename from modelopt/torch/_compress/mip/run_puzzle.py rename to modelopt/torch/puzzletron/mip/run_puzzle.py index 4868479e23..72919d27cd 100644 --- a/modelopt/torch/_compress/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -20,32 +20,33 @@ import dataclasses import enum import json +from collections.abc import Hashable, Iterable from copy import deepcopy from pathlib import Path -from typing import Any, Hashable, Iterable, List, Literal, TypeAlias +from typing import Any, Literal, TypeAlias import numpy as np import yaml from omegaconf import DictConfig, ListConfig, OmegaConf -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, FFNConfig, ) -from modelopt.torch._compress.mip.mip_with_multi_layer_replacements import ( +from modelopt.torch.puzzletron.mip.mip_with_multi_layer_replacements import ( run_mip as run_multi_layer_replacement_mip, ) -from modelopt.torch._compress.replacement_library.replacement_utils import ( +from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( extract_block_configs_and_locations, parse_layer_replacement, replacement_is_teacher, ) -from modelopt.torch._compress.tools.checkpoint_utils import load_model_config -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.utils.parsing import get_nested_key, parse_json, parse_path -from modelopt.torch._compress.utils.utils import block_config_to_str, solution_to_str +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_json, parse_path +from modelopt.torch.puzzletron.utils.utils import block_config_to_str, solution_to_str """ Usage: @@ -418,7 +419,7 @@ def _assert_valid_config(args, puzzle_profile): exit(1) -def _get_minimal_unique_names(dicts: List[dict]) -> List[str]: +def _get_minimal_unique_names(dicts: list[dict]) -> list[str]: all_keys = set(k for d in dicts for k in d.keys()) all_values = {k: set(d[k] for d in dicts if k in d) for k in all_keys} non_common_keys = [k for k, values in all_values.items() if len(values) > 1] @@ -426,7 +427,7 @@ def _get_minimal_unique_names(dicts: List[dict]) -> List[str]: return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts] -def run_puzzle(args: argparse.Namespace | DictConfig) -> List[str]: +def run_puzzle(args: argparse.Namespace | DictConfig) -> list[str]: # Loads config from args/puzzle_profile if args.puzzle_profile is not None: with open(args.puzzle_profile) as f: @@ -578,9 +579,7 @@ def _parse_teacher_block_metrics( "block_idx": block_idx, "parent_layer_indices": [block_idx], "metrics": { - **{ - metric_name: 0.0 for metric_name in all_metric_names - }, # default value 0. for teacher + **dict.fromkeys(all_metric_names, 0.0), # default value 0. for teacher **_extract_average_metrics(raw_metrics), # override with real value if exists }, **( @@ -597,7 +596,7 @@ def _parse_teacher_block_metrics( def _extract_average_metrics(raw_metrics: dict[str, dict]) -> dict[str, float]: average_metrics = dict() - for metric_name in raw_metrics.keys(): + for metric_name in raw_metrics: metric_dict = raw_metrics[metric_name] if isinstance(metric_dict, dict) and ("avg" in metric_dict.keys()): metric_value = raw_metrics[metric_name]["avg"] diff --git a/modelopt/torch/_compress/mip/utils.py b/modelopt/torch/puzzletron/mip/utils.py similarity index 100% rename from modelopt/torch/_compress/mip/utils.py rename to modelopt/torch/puzzletron/mip/utils.py diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py similarity index 71% rename from modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py rename to modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py index 55b9d10b0f..5e1eace934 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py @@ -13,30 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Compress NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). +"""Puzzletron NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). It is used by mtn.convert() to convert a model from HF format to DeciLM format + do pruning scoring and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search. """ -import datetime from pathlib import Path -import torch from torch import nn -import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models -import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts -import modelopt.torch._compress.scoring.scoring as scoring +import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.puzzletron.scoring.scoring as scoring import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress import build_library_and_stats -from modelopt.torch._compress.activation_scoring import score_pruning_activations -from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( - convert_llama3_to_decilm, -) -from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir -from modelopt.torch._compress.tools.logger import mprint from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.mode import ( @@ -47,14 +37,21 @@ RestoreEntrypoint, ) from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict +from modelopt.torch.puzzletron import build_library_and_stats +from modelopt.torch.puzzletron.activation_scoring import score_pruning_activations +from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( + convert_llama3_to_decilm, +) +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir +from modelopt.torch.puzzletron.tools.logger import mprint -class CompressModel(nn.Module): - pass # No model implementation is needed for the compress mode +class PuzzletronModel(nn.Module): + pass # No model implementation is needed for the puzzletron mode -class CompressConfig(ModeloptBaseConfig): - """Configuration for Compress NAS algorithm.""" +class PuzzletronConfig(ModeloptBaseConfig): + """Configuration for Puzzletron NAS algorithm.""" # Input model path to compress in the HF format input_model_path: str = ModeloptField( @@ -92,7 +89,7 @@ class CompressConfig(ModeloptBaseConfig): ) -def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertReturnType: +def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: """1. Convert the model from HF format to DeciLM format. 2. Score the pruning activations. 3. Prune the model and save pruned checkpoints @@ -118,7 +115,7 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR # Convert Llama3 model to DeciLM model # TODO: Make it generic, do not call convert_llama3_to_decilm directly. if dist.is_master(): - mprint("Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)") + mprint("Puzzletron Progress 2/8: converting model from HF to DeciLM (single-gpu)") hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable convert_llama3_to_decilm( input_dir=config.input_model_path, @@ -127,13 +124,13 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR dist.barrier() # Score_pruning_activations (distributed processing) - mprint("Compress Progress 3/8: scoring pruning activations (multi-gpu)") + mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)") score_pruning_activations.launch_score_activations(hydra_cfg) # Prune the model and save pruned checkpoints if dist.is_master(): mprint( - "Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" + "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" ) pruning_ckpts.launch_prune_ckpt(hydra_cfg) dist.barrier() @@ -141,58 +138,57 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR return model, {} -def restore_compress_model( - model: nn.Module, config: CompressConfig, metadata: MetadataDict +def restore_puzzletron_model( + model: nn.Module, config: PuzzletronConfig, metadata: MetadataDict ) -> nn.Module: - """Restore is not needed for the compress mode as we are not saving any model state""" + """Restore is not needed for the puzzletron mode as we are not saving any model state""" return model @NASModeRegistry.register_mode -class CompressDescriptor(ModeDescriptor): - """Descriptor for the Compress mode.""" +class PuzzletronDescriptor(ModeDescriptor): + """Descriptor for the Puzzletron mode.""" @property def name(self) -> str: """String identifier for this mode.""" - return "compress" + return "puzzletron" @property def config_class(self) -> type[ModeloptBaseConfig]: """Configuration class for this mode.""" - return CompressConfig + return PuzzletronConfig @property def search_algorithm(self) -> type[BaseSearcher]: """Return the associated searcher implementation.""" - - return CompressSearcher + return PuzzletronSearcher @property def convert(self) -> ConvertEntrypoint: """Entrypoint to convert a model.""" - return convert_compress_model + return convert_puzzletron_model @property def restore(self) -> RestoreEntrypoint: """Entrypoint to restore a model.""" - return restore_compress_model + return restore_puzzletron_model @property def export_mode(self) -> str | None: """The mode that corresponds to the export mode. For now, this will be a no-op as there is no modelopt's concept of search space defined - for the compress algorithm. + for the puzzletron algorithm. """ return "export_nas" -class CompressSearcher(BaseSearcher): - """Runs NAS search for the Compress mode.""" +class PuzzletronSearcher(BaseSearcher): + """Runs NAS search for the Puzzletron mode.""" @property def default_state_dict(self) -> SearchStateDict: - """Not needed for the compress mode as we are not saving any model state""" + """Not needed for the puzzletron mode as we are not saving any model state""" return {} def run_search(self) -> None: @@ -209,15 +205,15 @@ def run_search(self) -> None: # Build_library_and_stats (single process) if dist.is_master(): mprint( - "Compress Progress 5/8: building replacement library and subblock statistics (single-gpu)" + "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)" ) build_library_and_stats.launch_build_library_and_stats(hydra_cfg) dist.barrier() # Calc_one_block_scores (distributed processing) - mprint("Compress Progress 6/8: calculating one block scores (multi-gpu)") + mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)") scoring.launch_scoring(hydra_cfg) # mip_and_realize_models (distributed processing) - mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") + mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/_compress/pruning/pruning_ckpts.py b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py similarity index 91% rename from modelopt/torch/_compress/pruning/pruning_ckpts.py rename to modelopt/torch/puzzletron/pruning/pruning_ckpts.py index b413a3f783..5a0dfed01d 100644 --- a/modelopt/torch/_compress/pruning/pruning_ckpts.py +++ b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py @@ -23,28 +23,24 @@ import json import os import time -from typing import Optional -import hydra from omegaconf import DictConfig -from modelopt.torch._compress.tools.bypassed_training.child_init import ( +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( GQAInitMode, HiddenSizeInitMode, LinearInitMode, MlpInitMode, ) -from modelopt.torch._compress.tools.bypassed_training.init_child_from_parent import ( +from modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent import ( init_child_from_parent, ) -from modelopt.torch._compress.tools.checkpoint_utils import load_model_config -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.validate_model import validate_model +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config +from modelopt.torch.puzzletron.tools.logger import mprint def launch_ffn_intermediates_prune_ckpt( - cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None + cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None ): for intermediate_size in cfg.pruning.intermediate_size_list: dirname = f"ffn_{intermediate_size}_attn_no_op" @@ -87,7 +83,7 @@ def launch_ffn_intermediates_prune_ckpt( def launch_attn_groups_prune_ckpt( - cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None + cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None ): for n_heads_in_group in cfg.pruning.n_heads_in_group_list: dirname = f"n_heads_in_group{n_heads_in_group}" @@ -154,17 +150,17 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): else: intermediate_sizes.append(None) - mprint(f"Teacher config:") + mprint("Teacher config:") mprint(f" - hidden_size: {parent_hidden_size}") mprint(f" - intermediate_sizes: {intermediate_sizes}") os.makedirs(os.path.join(cfg.puzzle_dir, "ckpts"), exist_ok=True) for hidden_size in cfg.pruning.hidden_size_list: - mprint(f"\n######################################################################") + mprint("\n######################################################################") mprint(f"Hidden Size = {hidden_size}") - mprint(f"######################################################################\n") + mprint("######################################################################\n") - mprint(f"Child config:") + mprint("Child config:") mprint(f" - hidden_size: {hidden_size}") # Create model config overrides with proper FFN configuration @@ -208,9 +204,9 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): def launch_experts_prune_ckpt( cfg: DictConfig, - max_save_workers: Optional[int] = None, - max_layer_workers: Optional[int] = None, - symlink_suffix: Optional[str] = None, + max_save_workers: int | None = None, + max_layer_workers: int | None = None, + symlink_suffix: str | None = None, ): for num_experts in cfg.pruning.num_experts_to_keep_list: dirname = f"num_experts_{num_experts}" @@ -256,7 +252,7 @@ def launch_experts_prune_ckpt( def launch_moe_ffn_intermediates_prune_ckpt( - cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None + cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None ): for intermediate_size in cfg.pruning.intermediate_size_list: dirname = f"moe_ffn_{intermediate_size}_attn_no_op" @@ -312,14 +308,14 @@ def launch_prune_ckpt(cfg: DictConfig): max_layer_workers = int(os.environ["PRUNING_LAYER_WORKERS"]) # Log optimization settings (extracted from individual pruning methods) - mprint(f"Optimization Settings:") + mprint("Optimization Settings:") mprint( f" - I/O workers (max_workers): {'auto-calculate' if max_save_workers is None else max_save_workers}" ) mprint( f" - Layer workers (max_layer_workers): {'auto-calculate' if max_layer_workers is None else max_layer_workers}" ) - mprint(f" (Override with env vars: PRUNING_IO_WORKERS, PRUNING_LAYER_WORKERS)") + mprint(" (Override with env vars: PRUNING_IO_WORKERS, PRUNING_LAYER_WORKERS)") if target_layer == "mlp.down_proj": launch_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) @@ -331,7 +327,7 @@ def launch_prune_ckpt(cfg: DictConfig): # Check if we should use symlink suffix for chained pruning symlink_suffix = getattr(cfg.pruning, "symlink_suffix", None) launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers, symlink_suffix) - elif target_layer == "regex:experts\.\d+\.down_proj$": + elif target_layer == r"regex:experts\.\d+\.down_proj$": launch_moe_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) else: raise NotImplementedError( diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/puzzletron/puzzletron.py similarity index 74% rename from modelopt/torch/_compress/compress.py rename to modelopt/torch/puzzletron/puzzletron.py index 21e9df2af0..1051fdbaf7 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -13,28 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" - -This module provides the main compression function for a model -using MIP-based NAS search algorithm. - -""" +"""This module provides the main compression function for a model using MIP-based NAS search algorithm.""" from omegaconf import DictConfig -import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations -import modelopt.torch._compress.build_library_and_stats as build_library_and_stats -import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models -import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts -import modelopt.torch._compress.scoring.scoring as scoring +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.build_library_and_stats as build_library_and_stats +import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.puzzletron.scoring.scoring as scoring import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir -def compress( +def puzzletron( hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str ) -> DictConfig: - """Compress a puzzletron model using the MIP-based NAS search algorithm. + """Compress a model using the MIP-based NAS search algorithm from Puzzletron. Args: hydra_config_dir (str): path to a hydra_config_dir that defines the search space @@ -45,7 +40,7 @@ def compress( Returns: Hydra config object after compressing the model. The same hydra configuration object is used across all compression steps. - @TODO: Investigate if this config object is immutable across steps and clarify + TODO: Investigate if this config object is immutable across steps and clarify """ # Step 0: Load puzzletron hydra config hydra_cfg = initialize_hydra_config_for_dir( diff --git a/modelopt/torch/_compress/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py similarity index 93% rename from modelopt/torch/_compress/replacement_library/build_replacement_library.py rename to modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index 760952a609..1618aceaf3 100644 --- a/modelopt/torch/_compress/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -12,56 +12,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -This module constructs the replacement library JSON files from a puzzle directory containing +"""This module constructs the replacement library JSON files from a puzzle directory containing multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock configurations, builds a library of available replacements, and generates solutions for layer replacement in compressed models. The resulting replacement library can then be used by ReplacementLibrary to efficiently load models with mixed teacher/student layers. - -Standard Puzzle Usage: -====================== -python -m modelopt.torch._compress.replacement_library.build_replacement_library PUZZLE_DIR - -Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended) -though you can supply an explicit --teacher_checkpoint_dir. - ---add_ffn_no_ops and --add_attention_no_ops are optional (default True), - - -Untrained puzzle run (with bypass): -=================================== -The subblock that doesn't interest you in the checkpoint should be no_op. - """ # mypy: ignore-errors import json from pathlib import Path -from typing import Any, Type +from typing import Any import pandas as pd from omegaconf import DictConfig -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, FFNConfig, ) -from modelopt.torch._compress.replacement_library.replacement_utils import ( +from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( is_replacement_identical_to_teacher, replacement_is_teacher, sort_replacements, ) -from modelopt.torch._compress.tools.checkpoint_utils import ( +from modelopt.torch.puzzletron.tools.checkpoint_utils import ( SAFETENSORS_SUBBLOCKS_DIR_NAME, is_valid_decilm_checkpoint, load_model_config, ) -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.utils.parsing import format_global_config -from modelopt.torch._compress.utils.utils import block_config_to_str, subblock_config_to_str +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.parsing import format_global_config +from modelopt.torch.puzzletron.utils.utils import block_config_to_str, subblock_config_to_str UNIQUE_SUBBLOCK_IDENTIFIER = ["block_config", "attention_config", "ffn_config", "block_idx"] CHECKPOINTS_DIR_NAME = "ckpts" @@ -73,8 +57,7 @@ def build_replacement_library( add_ffn_no_ops: bool = True, add_attention_no_ops: bool = True, ) -> None: - """ - For normal puzzle runs, use default values. + """For normal puzzle runs, use default values. For advanced use cases, see the Usage section. """ master_puzzle_dir = Path(master_puzzle_dir) @@ -107,9 +90,7 @@ def build_replacement_library( def launch_build_replacement_library(cfg: DictConfig) -> None: - """ - Launch the build replacement library function with Hydra configuration. - """ + """Launch the build replacement library function with Hydra configuration.""" mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}") mprint(f"Teacher directory: {cfg.teacher_dir}") mprint( @@ -132,8 +113,8 @@ def infer_teacher_dir( teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher" if not teacher_checkpoint_dir.exists(): raise ValueError( - f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the " - f"teacher dir under '{{PUZZLE_DIR}}/ckpts'." + "You must either provide the --teacher_checkpoint_dir argument, or create a link to the " + "teacher dir under '{PUZZLE_DIR}/ckpts'." ) teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute() return teacher_checkpoint_dir @@ -381,7 +362,7 @@ def _add_no_op_subblock_rows( def _get_rows_with_no_op_subblock( subblocks_df: pd.DataFrame, no_op_subblock: str -) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]: +) -> tuple[pd.DataFrame, type[AttentionConfig] | type[FFNConfig]]: other_subblock = "ffn" if no_op_subblock == "attention" else "attention" subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig no_op_subblock_config = subblock_cls(no_op=True) diff --git a/modelopt/torch/_compress/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py similarity index 96% rename from modelopt/torch/_compress/replacement_library/replacement_library.py rename to modelopt/torch/puzzletron/replacement_library/replacement_library.py index 5e2fee6f0d..bf6cc66362 100644 --- a/modelopt/torch/_compress/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -12,8 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Replacement library for efficiently loading and managing layer-replaced DeciLM models. +"""Replacement library for efficiently loading and managing layer-replaced DeciLM models. - Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations """ # mypy: ignore-errors @@ -21,7 +20,6 @@ import json import re from pathlib import Path -from typing import Optional import numpy as np import torch @@ -31,21 +29,21 @@ from torch import nn import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, DeciLMForCausalLM, DeciLMMultiDecoderLayer, DeciLMRMSNorm, LMHead, ) -from modelopt.torch._compress.replacement_library.replacement_utils import ( +from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( extract_block_configs_and_locations, parse_layer_replacement, sort_replacements, weights_path_to_checkpoint_dir, ) -from modelopt.torch._compress.tools.checkpoint_utils import ( +from modelopt.torch.puzzletron.tools.checkpoint_utils import ( PTH_SUBBLOCKS_DIR_NAME, SAFETENSORS_SUBBLOCKS_DIR_NAME, infer_weights_dtype, @@ -53,7 +51,7 @@ init_module_with_state_dict, load_model_config, ) -from modelopt.torch._compress.tools.sharded_checkpoint_utils import ( +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( create_dummy_model, is_in_safetensors_format, load_sharded_state_dict, @@ -64,7 +62,7 @@ class ReplacementLibrary: def __init__( self, replacement_library_path: str | Path, - model_config_overrides: Optional[dict] = None, + model_config_overrides: dict | None = None, ): self.replacement_library = self._load_replacement_library(replacement_library_path) self._ensure_all_checkpoints_are_split() @@ -223,7 +221,7 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: if len(state_dict) > 0: block_indices = [ int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) - for param_name in state_dict.keys() + for param_name in state_dict ] assert sorted(set(block_indices)) == list( range(min(block_indices), max(block_indices) + 1) @@ -318,7 +316,7 @@ def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) return partial_state_dict[param_name] - non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" + non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / "non_block.pth" assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) non_block_state_dict = torch.load(non_block_pth_path) return non_block_state_dict[param_name] diff --git a/modelopt/torch/_compress/replacement_library/replacement_utils.py b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py similarity index 91% rename from modelopt/torch/_compress/replacement_library/replacement_utils.py rename to modelopt/torch/puzzletron/replacement_library/replacement_utils.py index 331357d2bb..68ba0b5fc3 100644 --- a/modelopt/torch/_compress/replacement_library/replacement_utils.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py @@ -12,8 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -This module provides helper functions for parsing, sorting, and analyzing layer replacement +"""This module provides helper functions for parsing, sorting, and analyzing layer replacement configurations used in the replacement library for model compression. """ @@ -22,9 +21,9 @@ from copy import deepcopy from pathlib import Path -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.mip.utils import sort_replacements +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.mip.utils import sort_replacements def parse_layer_replacement(layer_replacement: dict | str) -> dict: @@ -44,7 +43,7 @@ def parse_layer_replacement(layer_replacement: dict | str) -> dict: return layer_replacement -# sort_replacements moved to modelopt.torch._compress.mip.utils and imported above +# sort_replacements moved to modelopt.torch.puzzletron.mip.utils and imported above def extract_block_configs_and_locations( diff --git a/modelopt/torch/_compress/scoring/scoring.py b/modelopt/torch/puzzletron/scoring/scoring.py similarity index 91% rename from modelopt/torch/_compress/scoring/scoring.py rename to modelopt/torch/puzzletron/scoring/scoring.py index 5f745b3990..8f1871de89 100644 --- a/modelopt/torch/_compress/scoring/scoring.py +++ b/modelopt/torch/puzzletron/scoring/scoring.py @@ -19,18 +19,16 @@ import os import re from glob import glob -from pathlib import Path import hydra import numpy as np import pandas as pd -import torch from omegaconf import DictConfig import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( +from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.validate_puzzle_with_multi_replacements import ( validate_puzzle_solutions, ) diff --git a/modelopt/torch/_compress/sewing_kit/__init__.py b/modelopt/torch/puzzletron/sewing_kit/__init__.py similarity index 100% rename from modelopt/torch/_compress/sewing_kit/__init__.py rename to modelopt/torch/puzzletron/sewing_kit/__init__.py diff --git a/modelopt/torch/_compress/sewing_kit/core.py b/modelopt/torch/puzzletron/sewing_kit/core.py similarity index 99% rename from modelopt/torch/_compress/sewing_kit/core.py rename to modelopt/torch/puzzletron/sewing_kit/core.py index 8f926954b5..41eaeee75f 100644 --- a/modelopt/torch/_compress/sewing_kit/core.py +++ b/modelopt/torch/puzzletron/sewing_kit/core.py @@ -197,6 +197,8 @@ def output( @dataclass class ExternalTarget(TargetWithNamedInputs, TargetWithNamedOutputs, metaclass=Singleton): + """External target for stitched modules.""" + @override def __hash__(self) -> int: return super().__hash__() diff --git a/modelopt/torch/_compress/sewing_kit/passage/__init__.py b/modelopt/torch/puzzletron/sewing_kit/passage/__init__.py similarity index 100% rename from modelopt/torch/_compress/sewing_kit/passage/__init__.py rename to modelopt/torch/puzzletron/sewing_kit/passage/__init__.py diff --git a/modelopt/torch/_compress/sewing_kit/passage/core.py b/modelopt/torch/puzzletron/sewing_kit/passage/core.py similarity index 99% rename from modelopt/torch/_compress/sewing_kit/passage/core.py rename to modelopt/torch/puzzletron/sewing_kit/passage/core.py index 22c720b503..c0fcb4b123 100644 --- a/modelopt/torch/_compress/sewing_kit/passage/core.py +++ b/modelopt/torch/puzzletron/sewing_kit/passage/core.py @@ -36,6 +36,8 @@ @dataclass class InputArgs: + """Container for input arguments to modules.""" + args: list[Any] kwargs: dict[str, Any] diff --git a/modelopt/torch/_compress/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py similarity index 100% rename from modelopt/torch/_compress/sewing_kit/utils.py rename to modelopt/torch/puzzletron/sewing_kit/utils.py diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py similarity index 95% rename from modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py rename to modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py index e25c8e38d4..2e8630bc98 100644 --- a/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py @@ -28,14 +28,14 @@ import numpy as np import torch -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, FFNConfig, MambaConfig, ) -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMMoe -from modelopt.torch._compress.utils.utils import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMMoe +from modelopt.torch.puzzletron.utils.utils import ( calculate_kv_dim, raise_unknown_subblock_config_error, sizeof_dtype, @@ -117,7 +117,7 @@ def calc_subblock_active_params( def load_moe_stats(stats_file: str) -> dict: - with open(stats_file, "r") as f: + with open(stats_file) as f: stats = json.load(f) return [np.array(l) / np.sum(l) if len(l) > 0 else 0 for l in stats] @@ -178,10 +178,9 @@ def calculate_attention_memory( kv_cache_dtype: torch.dtype, allocate_prefill_query: bool, ) -> dict[str, float]: - """ - allocate_prefill_query: infery-llm style. - Infery used a unified Wqkv matrix, so before extracting the kv-cache, - the query also had to be kept in-memory, once per layer. + """allocate_prefill_query: infery-llm style. + Infery used a unified Wqkv matrix, so before extracting the kv-cache, + the query also had to be kept in-memory, once per layer. """ seq_len = prefill_seq_len + generation_seq_len if ( diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py similarity index 93% rename from modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py rename to modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index 76e6c34281..07597eb5c0 100644 --- a/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -19,10 +19,11 @@ import dataclasses import json import os +from collections.abc import Iterable from functools import partial from itertools import product from pathlib import Path -from typing import Iterable, Optional, Type, TypeVar +from typing import TypeVar os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -32,38 +33,29 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from tqdm import tqdm -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, FFNConfig, SubblockConfig, ) -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement -from modelopt.torch._compress.subblock_stats.calc_subblock_params_and_memory import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement +from modelopt.torch.puzzletron.subblock_stats.calc_subblock_params_and_memory import ( calc_subblock_active_params, calculate_non_block_memory, calculate_non_block_params, calculate_subblock_memory, calculate_subblock_params, ) -from modelopt.torch._compress.tools.checkpoint_utils import load_model_config -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.utils.parsing import format_global_config +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.parsing import format_global_config # Type variable for dataclasses T_DataClass = TypeVar("T_DataClass") -""" -Usage: -python -m modelopt.torch._compress.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ] - ---benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime, - only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker. - -""" - def calculate_subblock_stats( calc_subblock_stats_config: DictConfig, @@ -77,7 +69,7 @@ def calculate_subblock_stats( n_embd: int, n_head: int, vocab_size: int, - benchmark_iterations: Optional[int], + benchmark_iterations: int | None, use_cuda_graph: bool, weights_dtype: torch.dtype, activations_dtype: torch.dtype, @@ -189,7 +181,6 @@ def calculate_subblock_stats( ) if is_calc_runtime: - pass # TODO: fix # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ @@ -215,9 +206,7 @@ def calculate_subblock_stats( def launch_calc_subblock_stats(cfg: DictConfig) -> None: - """ - Launch the calc subblock stats function with Hydra configuration. - """ + """Launch the calc subblock stats function with Hydra configuration.""" mprint(f"Calculating subblock stats for puzzle directory: {cfg.puzzle_dir}") mprint(f"Teacher directory: {cfg.teacher_dir}") mprint( @@ -456,7 +445,7 @@ def _load_subblock_configs_from_replacement_library( return subblock_configs -T_DataClass: TypeVar = Type[dataclasses.dataclass] +T_DataClass: TypeVar = type[dataclasses.dataclass] def _dataclass_from_dict( @@ -523,10 +512,7 @@ def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> di stats for stats in subblock_stats if all( - [ - stats["args"][key] == corresponding_bf16_args[key] - for key in corresponding_bf16_args.keys() - ] + [stats["args"][key] == corresponding_bf16_args[key] for key in corresponding_bf16_args] ) ] if len(matching_bf16_stats) == 0: diff --git a/modelopt/torch/_compress/tools/__init__.py b/modelopt/torch/puzzletron/tools/__init__.py similarity index 100% rename from modelopt/torch/_compress/tools/__init__.py rename to modelopt/torch/puzzletron/tools/__init__.py diff --git a/modelopt/torch/_compress/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py similarity index 95% rename from modelopt/torch/_compress/tools/bypassed_training/child_init.py rename to modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index 1bd36fa090..3981b62e34 100644 --- a/modelopt/torch/_compress/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -22,23 +22,24 @@ import os import re import time +from collections.abc import Callable from copy import deepcopy from enum import Enum from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any import torch from typeguard import check_type -from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( SUBBLOCK_CLS_DICT, BlockConfig, _get_dataclass_type, _is_dataclass_type, ) -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.tools.logger import aprint, mprint class GQAInitMode(Enum): @@ -92,16 +93,15 @@ def _process_single_layer( new_config: DeciLMConfig, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, - mlp_init_config: Optional[dict[str, Any]], + mlp_init_config: dict[str, Any] | None, linear_init_mode: LinearInitMode, ignored_keys: set, keys: dict, is_original_mha: bool, head_size: int, hidden_size: int, -) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]: - """ - Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). +) -> tuple[dict[str, torch.Tensor], dict[str, str]]: + """Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). Thread-safe function for parallel layer processing. """ layer_out_state_dict = {} @@ -119,13 +119,13 @@ def _process_single_layer( o_key = f"{attn_prefix}.o_proj.{part}" attn_keys = [q_key, k_key, v_key, o_key] # Drop attn keys that don't exist and required to be in the new state_dict - attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] + attn_keys = [key for key in attn_keys if key in new_state_dict] if len(attn_keys) > 0 and all(key in keys for key in attn_keys): for key in attn_keys: keys_to_remove[key] = keys[key] if all(key not in ignored_keys for key in attn_keys): is_student_and_teacher_have_same_attention_implementation = all( - key in new_state_dict.keys() for key in attn_keys + key in new_state_dict for key in attn_keys ) if is_student_and_teacher_have_same_attention_implementation: if part == "weight": @@ -168,7 +168,7 @@ def _process_single_layer( else: linear_attn_key = f"{attn_prefix}.linear_attn.weight" - is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict.keys() + is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict if is_student_attn_replaced_with_linear: if linear_init_mode == LinearInitMode.Random: layer_out_state_dict[linear_attn_key] = new_state_dict[linear_attn_key] @@ -180,7 +180,7 @@ def _process_single_layer( raise ValueError(f"Unknown {linear_init_mode=}") else: # student attn random init - for new_key in new_state_dict.keys(): + for new_key in new_state_dict: if attn_prefix in new_key: layer_out_state_dict[new_key] = new_state_dict[new_key] @@ -190,7 +190,7 @@ def _process_single_layer( mlp_prefix = f"model.layers.{layer_idx}.mlp" linear_mlp_key = f"{mlp_prefix}.linear_mlp.weight" - is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict.keys() + is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict if is_student_mlp_replaced_with_linear: if linear_init_mode == LinearInitMode.Random: layer_out_state_dict[linear_mlp_key] = new_state_dict[linear_mlp_key] @@ -312,7 +312,7 @@ def _process_single_layer( ]: key_possibly_missing_in_student = f".{layer_idx}.{key_possibly_missing_in_student}" is_key_missing_from_student = ( - len([k for k in new_state_dict.keys() if key_possibly_missing_in_student in k]) == 0 + len([k for k in new_state_dict if key_possibly_missing_in_student in k]) == 0 ) if is_key_missing_from_student: for k in list(keys.keys()): @@ -331,12 +331,12 @@ def create_child_state_dict( gqa_init_mode: GQAInitMode, ignore_fn: IgnoreFn = default_ignore_fn, mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, - mlp_init_config: Optional[dict[str, Any]] = None, - owned_block_indexes: Optional[set[int]] = None, + mlp_init_config: dict[str, Any] | None = None, + owned_block_indexes: set[int] | None = None, linear_init_mode: LinearInitMode = LinearInitMode.Random, hidden_size_init_mode: HiddenSizeInitMode = HiddenSizeInitMode.CopyAsIs, - channel_importance_path: Optional[str] = None, - max_layer_workers: Optional[int] = None, # Now optional - will auto-calculate if None + channel_importance_path: str | None = None, + max_layer_workers: int | None = None, # Now optional - will auto-calculate if None ): mprint("=== Starting create_child_state_dict with optimizations ===") total_start_time = time.time() @@ -391,14 +391,14 @@ def create_child_state_dict( hidden_size = original_config.hidden_size - ignored_keys = set([key for key in original_state_dict.keys() if ignore_fn(key)]) + ignored_keys = set([key for key in original_state_dict if ignore_fn(key)]) for key in ignored_keys: aprint(f"Ignoring key {key} and taking its init from new_state_dict") out_state_dict[key] = new_state_dict[key] keys = { match.group(1) if (match := re.search(r"(h\.\d+\..*)", key)) is not None else key: key - for key in original_state_dict.keys() + for key in original_state_dict } setup_time = time.time() - setup_start_time mprint(f"Phase 1 - Setup and memory pre-allocation: {setup_time:.2f}s") @@ -527,7 +527,7 @@ def _generate_moe_keys(layer_idx: int, num_experts: int) -> tuple[str, dict[str, def _concatenate_experts_into_dense_ffn( original_state_dict: dict[str, torch.Tensor], - mlp_init_config: Optional[dict], + mlp_init_config: dict | None, hidden_size: int, layer_idx: int, child_block_config: BlockConfig, @@ -585,8 +585,7 @@ def _concatenate_experts_into_dense_ffn( "concat_dims and experts_weights must have the same keys" ) concat_routed_state_dict = { - name: torch.cat(experts_weights[name], dim=concat_dims[name]) - for name in concat_dims.keys() + name: torch.cat(experts_weights[name], dim=concat_dims[name]) for name in concat_dims } # turn the shared expert into a normal FFN. concatenate the pruned routed experts if needed. @@ -646,16 +645,16 @@ def _verify_state_dicts_match( def _init_mlp( *, - mlp_init_mode: Union[MlpInitMode, str], + mlp_init_mode: MlpInitMode | str, layer_idx: int, original_config: DeciLMConfig, - mlp_init_config: Optional[dict[str, Any]], + mlp_init_config: dict[str, Any] | None, original_state_dict: dict, new_state_dict: dict, new_config: DeciLMConfig, keys: dict[str, str], ignored_keys: set[str], - expert_idx: Optional[int] = None, + expert_idx: int | None = None, ) -> dict[str, torch.Tensor]: out_state_dict = {} @@ -680,7 +679,7 @@ def _init_mlp( projection_matrix = None for mlp_key in mlp_keys: expanded_dim = 1 if "down_proj" in mlp_key else 0 - if mlp_key in new_state_dict.keys(): + if mlp_key in new_state_dict: mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( mlp_init_mode, expanded_dim, @@ -700,17 +699,17 @@ def _init_mlp( def _init_mlp_module( - mlp_init_mode: Union[MlpInitMode, str], + mlp_init_mode: MlpInitMode | str, expanded_dim: int, new_item: torch.Tensor, new_config: DeciLMConfig, orig_item: torch.Tensor, original_config: DeciLMConfig, - mlp_init_config: Optional[dict[str, Any]], - pruned_filters: Optional[torch.Tensor] = None, - projection_matrix: Optional[dict[str, torch.Tensor]] = None, - mlp_prefix: Optional[str] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: + mlp_init_config: dict[str, Any] | None, + pruned_filters: torch.Tensor | None = None, + projection_matrix: dict[str, torch.Tensor] | None = None, + mlp_prefix: str | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor] | None]: if isinstance(mlp_init_mode, str): mlp_init_mode = MlpInitMode(mlp_init_mode) assert orig_item.ndim == 2, f"{orig_item.ndim=}" @@ -779,14 +778,14 @@ def _init_mlp_module( def _init_moe_module( *, - mlp_init_mode: Union[MlpInitMode, str], - mlp_init_config: Optional[dict[str, Any]], + mlp_init_mode: MlpInitMode | str, + mlp_init_config: dict[str, Any] | None, layer_idx: int, orig_router_weight: torch.Tensor, orig_experts_weights: dict[str, list[torch.Tensor]], new_router_weight: torch.Tensor, new_experts_weights: dict[str, list[torch.Tensor]], -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: +) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor] | None]: if isinstance(mlp_init_mode, str): mlp_init_mode = MlpInitMode(mlp_init_mode) @@ -849,11 +848,11 @@ def _prune_experts_by_score( return result_router_weight, result_experts_weights -def _load_expert_scores(mlp_init_config: Optional[dict[str, Any]]) -> list[list[int | float]]: +def _load_expert_scores(mlp_init_config: dict[str, Any] | None) -> list[list[int | float]]: assert mlp_init_config is not None if "expert_scores_file" in mlp_init_config: expert_scores_file = mlp_init_config["expert_scores_file"] - with open(expert_scores_file, "r") as f: + with open(expert_scores_file) as f: expert_scores = json.load(f) elif "activations_log_dir" in mlp_init_config: _cache_activations_log(mlp_init_config) @@ -1111,7 +1110,7 @@ def _init_attention_biases( bias_sd["k"] = bias_sd["k"][:, 0] bias_sd["v"] = bias_sd["v"][:, 0] elif gqa_init_mode == GQAInitMode.CopyAsIs: - for key in bias_sd.keys(): + for key in bias_sd: assert new_bias_sd[key].shape == bias_sd[key].shape, ( f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" ) @@ -1227,8 +1226,7 @@ def _init_linear_attn( v_key: str, o_key: str, ) -> torch.Tensor: - """ - Init a linear layer that operates like an attention layer that assigns score 1 to the current token + """Init a linear layer that operates like an attention layer that assigns score 1 to the current token and score 0 to all others: out = (Wo @ Wv) @ x """ n_embd = parent_config.hidden_size @@ -1247,9 +1245,7 @@ def _init_linear_attn( def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.Tensor: - """ - A linear layer that does (W_down @ W_up) @ x, ignoring W_gate. - """ + """A linear layer that does (W_down @ W_up) @ x, ignoring W_gate.""" if "linear_mlp.weight" in teacher_mlp_state_dict: # if the teacher itself is a linear layer return teacher_mlp_state_dict["linear_mlp.weight"] @@ -1318,8 +1314,7 @@ def _parse_model_config_overrides( model_config_overrides_json: str | dict | Path | list[dict], n_layer: int, ) -> list[dict[str, Any]]: - """ - example model_config_overrides_json: + """Example model_config_overrides_json: { "attention": [{"n_heads_in_group": 2}], "ffn": [{"intermediate_size": 14336}] @@ -1368,11 +1363,10 @@ def _apply_hidden_size_pruning( new_config: DeciLMConfig, original_config: DeciLMConfig, hidden_size_init_mode: HiddenSizeInitMode, - channel_importance_path: Optional[str] = None, - owned_block_indexes: Optional[list[int]] = None, + channel_importance_path: str | None = None, + owned_block_indexes: list[int] | None = None, ) -> dict[str, torch.Tensor]: - """ - Apply hidden size pruning to all layers that depend on hidden_size. + """Apply hidden size pruning to all layers that depend on hidden_size. This includes embeddings, layer norms, and any linear layers that haven't been handled yet. """ if isinstance(hidden_size_init_mode, str): @@ -1387,7 +1381,7 @@ def _apply_hidden_size_pruning( # Load channel ranking if needed if hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: if channel_importance_path is not None: - with open(channel_importance_path, "r") as f: + with open(channel_importance_path) as f: channel_ranking = json.load(f)["channel_importance_ranking"] else: raise ValueError( @@ -1580,12 +1574,10 @@ def _prune_hidden_size_dimension( original_tensor: torch.Tensor, new_hidden_size: int, hidden_size_init_mode: HiddenSizeInitMode, - channel_ranking: Optional[list[int]] = None, + channel_ranking: list[int] | None = None, dim: int = -1, ) -> torch.Tensor: - """ - Prune a tensor along the specified dimension to match the new hidden size. - """ + """Prune a tensor along the specified dimension to match the new hidden size.""" original_size = original_tensor.shape[dim] if hidden_size_init_mode == HiddenSizeInitMode.Random: diff --git a/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py similarity index 87% rename from modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py rename to modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index f06db92fbe..46e403c5f4 100644 --- a/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -16,16 +16,14 @@ """TODO Add description""" -import argparse import json import time -from typing import Optional import torch import yaml -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM -from modelopt.torch._compress.tools.bypassed_training.child_init import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( GQAInitMode, HiddenSizeInitMode, LinearInitMode, @@ -33,16 +31,16 @@ create_child_state_dict, update_model_config, ) -from modelopt.torch._compress.tools.checkpoint_utils import ( +from modelopt.torch.puzzletron.tools.checkpoint_utils import ( copy_tokenizer, load_model_config, load_state_dict, ) -from modelopt.torch._compress.tools.checkpoint_utils_hf import ( +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( _save_checkpoint, copy_deci_lm_hf_code, ) -from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch.puzzletron.tools.logger import mprint """ @@ -87,7 +85,7 @@ echo "MODEL_CONFIG_OVERRIDES_JSON:" echo "${MODEL_CONFIG_OVERRIDES_JSON}" -python -m modelopt.torch._compress.tools.bypassed_training.init_child_from_parent \ +python -m modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent \ --parent_checkpoint_dir="$PARENT_DIR" \ --model_config_overrides_json="$MODEL_CONFIG_OVERRIDES_JSON" \ --output_checkpoint_dir="$OUTPUT_DIR" \ @@ -102,15 +100,14 @@ def init_child_from_parent( output_checkpoint_dir: str, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, - mlp_init_config_yaml: Optional[str], + mlp_init_config_yaml: str | None, linear_init_mode: LinearInitMode, - hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, - channel_importance_path: Optional[str] = None, - max_workers: Optional[int] = None, # Auto-calculate optimal workers if None - max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None + hidden_size_init_mode: HiddenSizeInitMode | None = None, + channel_importance_path: str | None = None, + max_workers: int | None = None, # Auto-calculate optimal workers if None + max_layer_workers: int | None = None, # Auto-calculate optimal workers if None ) -> None: - """ - Init child models from parent models in the style of bypass training, + """Init child models from parent models in the style of bypass training, but without having to run the entire bypass pipeline. I/O Optimization Parameters: @@ -210,7 +207,7 @@ def init_child_from_parent( total_core_time = create_child_state_dict_time + save_checkpoint_time actual_layer_workers = max_layer_workers if max_layer_workers else "auto" actual_io_workers = max_workers if max_workers else "auto" - mprint(f"\n=== PROFILING SUMMARY ===") + mprint("\n=== PROFILING SUMMARY ===") mprint( f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)" ) @@ -219,4 +216,4 @@ def init_child_from_parent( ) mprint(f"Total core processing: {total_core_time:.2f}s") mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") - mprint(f"=========================\n") + mprint("=========================\n") diff --git a/modelopt/torch/_compress/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py similarity index 92% rename from modelopt/torch/_compress/tools/checkpoint_utils.py rename to modelopt/torch/puzzletron/tools/checkpoint_utils.py index 43d3c43641..f08b89e449 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -14,8 +14,7 @@ # limitations under the License. # mypy: ignore-errors -""" -It provides general utilities for loading and initializing PyTorch model checkpoints, +"""It provides general utilities for loading and initializing PyTorch model checkpoints, particularly for DeciLM models. """ @@ -31,8 +30,8 @@ from transformers import AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME -from modelopt.torch._compress.tools.checkpoint_utils_hf import load_model_config -from modelopt.torch._compress.tools.common import infer_weights_dtype +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch.puzzletron.tools.common import infer_weights_dtype SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" @@ -56,7 +55,7 @@ def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( checkpoint_dir / SAFE_WEIGHTS_NAME ).exists(): - from modelopt.torch._compress.tools.sharded_checkpoint_utils import ( + from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( load_sharded_state_dict, # local import to avoid circular import ) @@ -124,9 +123,7 @@ def init_empty_module( def skip_init(module_cls, *args, **kwargs) -> nn.Module: - """ - Heavily inspired by torch.nn.utils.skip_init but does not require the module to accept a "device" kwarg. - """ + """Heavily inspired by torch.nn.utils.skip_init but does not require the module to accept a "device" kwarg.""" if not issubclass(module_cls, torch.nn.Module): raise RuntimeError(f"Expected a Module; got {module_cls}") @@ -165,8 +162,7 @@ def copy_tokenizer( target_dir: Path | str, on_failure: Literal["raise", "warn"] = "raise", ) -> None: - """ - Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available) + """Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available) to avoid collision between transformers versions. """ source_tokenizer_name_path = Path(source_dir_or_tokenizer_name) / "tokenizer_name.txt" diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py similarity index 94% rename from modelopt/torch/_compress/tools/checkpoint_utils_hf.py rename to modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 3c73498d5f..f52c12d26f 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -14,8 +14,7 @@ # limitations under the License. # mypy: ignore-errors -""" -Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, +"""Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, particularly for DeciLM models. """ @@ -34,13 +33,13 @@ from safetensors.torch import save_file as safe_save_file from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from modelopt.torch._compress.decilm import deci_lm_hf_code -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM -from modelopt.torch._compress.tools.common import infer_weights_dtype -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.post_init_sparse import SparsityMethod -from modelopt.torch._compress.tools.robust_json import json_dumps +from modelopt.torch.puzzletron.decilm import deci_lm_hf_code +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.puzzletron.tools.common import infer_weights_dtype +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.post_init_sparse import SparsityMethod +from modelopt.torch.puzzletron.tools.robust_json import json_dumps SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" @@ -70,11 +69,10 @@ def load_checkpoint( model_config_overrides: dict | None = None, ignore_unexpected_config_keys: bool = False, ) -> DeciLMForCausalLM: - """ - Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your + """Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your local repo code, not the code inside the checkpoint. """ - from modelopt.torch._compress.tools.checkpoint_utils import ( + from modelopt.torch.puzzletron.tools.checkpoint_utils import ( load_state_dict, # prevent circular import ) @@ -193,7 +191,7 @@ def _save_checkpoint( def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: - from modelopt.torch._compress.tools.checkpoint_utils import ( + from modelopt.torch.puzzletron.tools.checkpoint_utils import ( load_state_dict, # prevent circular import ) @@ -374,8 +372,7 @@ def _write_file_process_safe( path: Path | str, write_fn: Callable[[Any, BinaryIO], None] = _write_text, ) -> None: - """ - Write a file in a multi-process safe way. + """Write a file in a multi-process safe way. If another process tries to write the same file using this method, the current process "gives up" and assumes that the matter is being taken care of by another process. @@ -444,9 +441,7 @@ def save_model_config(model_config: DeciLMConfig, checkpoint_dir: Path | str) -> def copy_deci_lm_hf_code(output_dir: Path | str) -> None: - """ - Copy the deci_lm_hf_code directory to the output directory. - """ + """Copy the deci_lm_hf_code directory to the output directory.""" output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) code_dir = Path(deci_lm_hf_code.__file__).parent diff --git a/modelopt/torch/_compress/tools/common.py b/modelopt/torch/puzzletron/tools/common.py similarity index 100% rename from modelopt/torch/_compress/tools/common.py rename to modelopt/torch/puzzletron/tools/common.py diff --git a/modelopt/torch/_compress/tools/hydra_utils.py b/modelopt/torch/puzzletron/tools/hydra_utils.py similarity index 100% rename from modelopt/torch/_compress/tools/hydra_utils.py rename to modelopt/torch/puzzletron/tools/hydra_utils.py diff --git a/modelopt/torch/_compress/tools/kd_model.py b/modelopt/torch/puzzletron/tools/kd_model.py similarity index 100% rename from modelopt/torch/_compress/tools/kd_model.py rename to modelopt/torch/puzzletron/tools/kd_model.py diff --git a/modelopt/torch/_compress/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py similarity index 92% rename from modelopt/torch/_compress/tools/logger.py rename to modelopt/torch/puzzletron/tools/logger.py index 3e8e213ca2..e4b87e3770 100644 --- a/modelopt/torch/_compress/tools/logger.py +++ b/modelopt/torch/puzzletron/tools/logger.py @@ -48,13 +48,15 @@ def __init__(self, name, level=logging.DEBUG): self.world_size = int(os.environ.get("WORLD_SIZE", 1)) def dist_log(self, msg: str, ranks: str = "main"): - """ - Log parameter msg with the given ranks. - parameter ranks: - "all": log with all ranks - "main": log with only rank 0 in node 0 - "last": log with only rank -1 in node 0 - "local_main": log with only rank 0 in all nodes + """Log parameter msg with the given ranks. + + Args: + msg: The message to log. + ranks: The ranks to log the message to. Choices are: + "all": log with all ranks + "main": log with only rank 0 in node 0 + "last": log with only rank -1 in node 0 + "local_main": log with only rank 0 in all nodes """ # print(msg, ranks) if ranks not in ["all", "main", "local_main", "last"]: diff --git a/modelopt/torch/_compress/tools/post_init_sparse.py b/modelopt/torch/puzzletron/tools/post_init_sparse.py similarity index 94% rename from modelopt/torch/_compress/tools/post_init_sparse.py rename to modelopt/torch/puzzletron/tools/post_init_sparse.py index 824d0856ca..e2c45c4030 100644 --- a/modelopt/torch/_compress/tools/post_init_sparse.py +++ b/modelopt/torch/puzzletron/tools/post_init_sparse.py @@ -17,7 +17,7 @@ from torch import nn from torch.nn.utils.prune import custom_from_mask -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM """ Converts a state dictionary from PyTorch's pruning format (with _orig and _mask suffixes) @@ -27,9 +27,7 @@ class SparsityMethod: def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """ - gets a model state_dict, returns a state_dict-like mask_dict with masks - """ + """Gets a model state_dict, returns a state_dict-like mask_dict with masks""" @staticmethod def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False): @@ -99,9 +97,7 @@ def do_sparsity(self, model: DeciLMForCausalLM, mask_dict=None): class SparsityMethod2o4(SparsityMethod): def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """ - gets a model state_dict, returns a state_dict-like mask_dict with masks - """ + """Gets a model state_dict, returns a state_dict-like mask_dict with masks""" mask_dict = {} for key, val in state_dict.items(): orig_size = val.shape diff --git a/modelopt/torch/_compress/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py similarity index 100% rename from modelopt/torch/_compress/tools/robust_json.py rename to modelopt/torch/puzzletron/tools/robust_json.py diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py similarity index 94% rename from modelopt/torch/_compress/tools/sharded_checkpoint_utils.py rename to modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py index 7a247bbdf0..1cb5e8489a 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -14,8 +14,7 @@ # limitations under the License. # mypy: ignore-errors -""" -Provides utilities for distributed loading, saving, and manipulation of +"""Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes. """ @@ -28,25 +27,23 @@ import torch import torch.distributed import torch.nn as nn -from huggingface_hub import split_torch_state_dict_into_shards from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file -from tqdm import tqdm from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, DeciLMForCausalLM, rope_type_to_class, ) -from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.utils.utils import EmptyInitOnDevice +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.utils import EmptyInitOnDevice class DummyModule(nn.Module): @@ -243,7 +240,7 @@ def create_sharded_model( def load_state_dict_to_shards( model_shard: torch.nn.Module, loaded_state_dict: dict | None = None ) -> None: - from modelopt.torch._compress.sewing_kit.utils import ( + from modelopt.torch.puzzletron.sewing_kit.utils import ( distributed_isend_obj, distributed_recv_obj, ) @@ -291,9 +288,7 @@ def load_state_dict_to_shards( def save_sharded_model( model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path ): - """ - out_path is usually output_checkpoint_path / "model.safetensors" - """ + """out_path is usually output_checkpoint_path / "model.safetensors" """ dist.barrier() if isinstance(model_shard, torch.nn.Module): @@ -351,9 +346,7 @@ def load_sharded_state_dict( keys_to_load: Iterable[str] | None = None, device: torch.device | str = "cpu", ) -> dict[str, torch.Tensor]: - """ - keys_to_load: entire state_dict if None, else partial state_dict containing only these keys - """ + """keys_to_load: entire state_dict if None, else partial state_dict containing only these keys""" shard_paths = _resolve_shard_paths(model_name_or_path) # print(f"shard_paths: {shard_paths}") partial_state_dict = {} diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py similarity index 73% rename from modelopt/torch/_compress/tools/validate_model.py rename to modelopt/torch/puzzletron/tools/validate_model.py index 456f9fab87..6c3dc3640c 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Provides a function to validate a model. Runs a model forward pass on a dataset and calculates +"""Provides a function to validate a model. Runs a model forward pass on a dataset and calculates the loss, and optionally registers hooks to capture the inputs and the outputs of pytorch modules that are used for activation scoring for pruning. @@ -36,19 +35,19 @@ ) import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( +from modelopt.torch.puzzletron.activation_scoring.activation_hooks.utils import ( register_activation_hooks, ) -from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint -from modelopt.torch._compress.tools.logger import aprint, mprint -from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model -from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader -from modelopt.torch._compress.utils.parsing import simple_parse_args_string -from modelopt.torch._compress.utils.validate_runtime_pipeline import ( +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_checkpoint +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.utils.data.dataloaders import create_validation_dataloader +from modelopt.torch.puzzletron.utils.parsing import simple_parse_args_string +from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( HiddenStatesAndLMHead, calculate_losses_pipeline, ) -from modelopt.torch._compress.utils.validation import calculate_losses +from modelopt.torch.puzzletron.utils.validation import calculate_losses """ Two goals: @@ -80,40 +79,45 @@ def validate_model( Args: args: Configuration object containing the following attributes: - Model Configuration: - - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. - Required unless model is passed directly. - - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration: - - dataset_path (str): Path to the validation dataset. - - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. - - data_column (str): Column name in dataset containing text data. - - block_size (int): Maximum sequence length for tokenization. - - eval_samples (int, optional): Number of samples to evaluate. Uses all if None. - - val_dataset_name (str): Name of validation dataset split. - - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. - - load_dataset_fn (callable, optional): Custom function to load the dataset. - - Data Processing: - - micro_batch_size (int): Batch size for evaluation. - - seed (int): Random seed for reproducibility. - - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. - - varlen (bool): Enable variable-length sequences. - - bos_rate (float): Rate of adding BOS token. - - fim_rate (float): Fill-in-the-middle rate for code completion tasks. - - fim_spm_rate (float): SPM-based fill-in-the-middle rate. - - Activation Hooks: - - activations_log_dir (str, optional): Directory to log activation scores. If provided, - hooks will be registered to capture activations. - - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. - If string, comma-separated format: "arg1=val1,arg2=val2". - - Execution Options: - - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. - - write_results (bool): Write validation results to file. + Model Configuration attributes: + + - ``model_name_or_path`` (str): Path to model checkpoint or HuggingFace model name. + Required unless model is passed directly. + - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration attributes: + + - ``dataset_path`` (str): Path to the validation dataset. + - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. + - ``data_column`` (str): Column name in dataset containing text data. + - ``block_size`` (int): Maximum sequence length for tokenization. + - ``eval_samples`` (int, optional): Number of samples to evaluate. Uses all if None. + - ``val_dataset_name`` (str): Name of validation dataset split. + - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. + - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. + + Data Processing attributes: + + - ``micro_batch_size`` (int): Batch size for evaluation. + - ``seed`` (int): Random seed for reproducibility. + - ``shuffle_seed`` (int, optional): Seed for shuffling data. Uses seed if None. + - ``varlen`` (bool): Enable variable-length sequences. + - ``bos_rate`` (float): Rate of adding BOS token. + - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. + - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. + + Activation Hooks attributes: + + - ``activations_log_dir`` (str, optional): Directory to log activation scores. + If provided, hooks will be registered to capture activations. + - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. + If string, comma-separated format: "arg1=val1,arg2=val2". + + Execution Options attributes: + + - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. + - ``write_results`` (bool): Write validation results to file. model: Pre-loaded model. If None, will be loaded from args.model_name_or_path. tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args. @@ -121,16 +125,17 @@ def validate_model( return_hidden_states: Whether to return hidden states from the model. pipeline_parallel: Enable pipeline parallelism for large models. calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. - False calculates only a small suite for efficiency. + False calculates only a small suite for efficiency. val_dataloader: Pre-created validation dataloader. If None, will be created from args. Returns: A tuple containing: + - losses: Dictionary mapping loss names to loss statistics (avg, per_sample). - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. + Returns (None, None) if not on master rank. """ - if val_dataloader is None: val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None validation_full_iters = ( @@ -157,7 +162,7 @@ def validate_model( ) # Create checkpoint manager with hooks - from modelopt.torch._compress.utils.checkpoint_manager import ScoringCheckpointManager + from modelopt.torch.puzzletron.utils.checkpoint_manager import ScoringCheckpointManager mprint( f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}" diff --git a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py similarity index 70% rename from modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py rename to modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index 6bc4d11b35..4e3266df4f 100644 --- a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -21,11 +21,9 @@ # mypy: ignore-errors import json -import shutil import warnings from functools import partial from pathlib import Path -from typing import Optional import torch from omegaconf import DictConfig @@ -33,25 +31,25 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.replacement_library.replacement_library import ReplacementLibrary -from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement -from modelopt.torch._compress.tools import validate_model -from modelopt.torch._compress.tools.checkpoint_utils import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.replacement_library.replacement_library import ReplacementLibrary +from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement +from modelopt.torch.puzzletron.tools import validate_model +from modelopt.torch.puzzletron.tools.checkpoint_utils import ( SAFETENSORS_SUBBLOCKS_DIR_NAME, copy_tokenizer, ) -from modelopt.torch._compress.tools.checkpoint_utils_hf import ( +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( copy_deci_lm_hf_code, save_checkpoint, save_safetensors_index, ) -from modelopt.torch._compress.tools.validation_utils import ( +from modelopt.torch.puzzletron.tools.validation_utils import ( validate_model_and_extract_hidden_states, validate_model_with_teacher_similarity_metrics, ) -from modelopt.torch._compress.utils.parsing import get_nested_key, parse_path -from modelopt.torch._compress.utils.validate_runtime_pipeline import perform_pipeline_stitches +from modelopt.torch.puzzletron.utils.parsing import get_nested_key +from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import perform_pipeline_stitches """ Usage Example: @@ -70,51 +68,58 @@ def validate_puzzle_solutions(args: DictConfig) -> None: Args: args: Configuration object containing the following attributes: - Puzzle Configuration (Required): - - replacement_library_path (Path): Path to the replacement library JSON file. - - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. - - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. - Validates all solutions if None. - - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. - - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. - - skip_validation (bool): If True, skip model validation and only save models if requested. - - save_models (bool): If True, save realized model checkpoints for each solution. - - Teacher/Tokenizer Configuration: - - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. - - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. - - Model Configuration (Required if skip_validation=False): - - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration (Required if skip_validation=False): - - dataset_path (str): Path to the validation dataset. - - data_column (str): Column name in dataset containing text data. - - block_size (int): Maximum sequence length for tokenization. - - eval_samples (int, optional): Number of samples to evaluate. - - val_dataset_name (str): Name of validation dataset split. - - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. - - load_dataset_fn (callable, optional): Custom function to load the dataset. - - Data Processing (Required if skip_validation=False): - - micro_batch_size (int): Batch size for evaluation. - - seed (int): Random seed for reproducibility. - - shuffle_seed (int, optional): Seed for shuffling data. - - varlen (bool): Enable variable-length sequences. - - bos_rate (float): Rate of adding BOS token. - - fim_rate (float): Fill-in-the-middle rate for code completion tasks. - - fim_spm_rate (float): SPM-based fill-in-the-middle rate. - - Output Configuration: - - output_dir (Path, optional): Directory to save validation results. - Auto-generated from solutions_path if not provided. - - Execution Options (Optional if skip_validation=False): - - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. - - write_results (bool): Write validation results to file. - - activations_log_dir (str, optional): Directory to log activation scores. - - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + Puzzle Configuration (Required) attributes: + + - ``replacement_library_path`` (Path): Path to the replacement library JSON file. + - ``solutions_path`` (Path): Path to puzzle solutions JSON file or directory containing solution files. + - ``solutions_to_validate`` (list[int], optional): Indices of specific solutions to validate. + Validates all solutions if None. + - ``sort_solutions_by`` (str, optional): JSON field path to sort solutions by before validation. + - ``bigger_is_better`` (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - ``skip_validation`` (bool): If True, skip model validation and only save models if requested. + - ``save_models`` (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration attributes: + + - ``teacher_dir`` (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False) attributes: + + - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False) attributes: + + - ``dataset_path`` (str): Path to the validation dataset. + - ``data_column`` (str): Column name in dataset containing text data. + - ``block_size`` (int): Maximum sequence length for tokenization. + - ``eval_samples`` (int, optional): Number of samples to evaluate. + - ``val_dataset_name`` (str): Name of validation dataset split. + - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. + - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False) attributes: + + - ``micro_batch_size`` (int): Batch size for evaluation. + - ``seed`` (int): Random seed for reproducibility. + - ``shuffle_seed`` (int, optional): Seed for shuffling data. + - ``varlen`` (bool): Enable variable-length sequences. + - ``bos_rate`` (float): Rate of adding BOS token. + - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. + - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. + + Output Configuration attributes: + + - ``output_dir`` (Path, optional): Directory to save validation results. + Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False) attributes: + + - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. + - ``write_results`` (bool): Write validation results to file. + - ``activations_log_dir`` (str, optional): Directory to log activation scores. + - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. Returns: None. Saves validation results and optionally model checkpoints to disk. @@ -273,7 +278,7 @@ def _extract_layer_replacements_from_puzzle_solution( def load_puzzle_solutions( solutions_path: Path, - sort_solutions_by: Optional[str], + sort_solutions_by: str | None, bigger_is_better: bool, ) -> list[dict]: assert solutions_path.exists(), f"{solutions_path=} does not exist" diff --git a/modelopt/torch/_compress/tools/validation_utils.py b/modelopt/torch/puzzletron/tools/validation_utils.py similarity index 88% rename from modelopt/torch/_compress/tools/validation_utils.py rename to modelopt/torch/puzzletron/tools/validation_utils.py index 6f0b1fcb5d..697977cdaf 100644 --- a/modelopt/torch/_compress/tools/validation_utils.py +++ b/modelopt/torch/puzzletron/tools/validation_utils.py @@ -21,7 +21,7 @@ # mypy: ignore-errors from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import torch from omegaconf import DictConfig, OmegaConf @@ -29,13 +29,13 @@ from transformers import PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.tools import validate_model -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.utils.validation import LowMemorySparseTensor +from modelopt.torch.puzzletron.tools import validate_model +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.validation import LowMemorySparseTensor if TYPE_CHECKING: - from modelopt.torch._compress.sewing_kit import StitchedModule + from modelopt.torch.puzzletron.sewing_kit import StitchedModule def validate_model_and_extract_hidden_states( @@ -44,7 +44,7 @@ def validate_model_and_extract_hidden_states( tokenizer: PreTrainedTokenizerBase, output_dir: str | Path, model_name: str, - extra_payload: Optional[dict[str, Any]] = None, + extra_payload: dict[str, Any] | None = None, pipeline_parallel: bool = False, val_dataloader=None, ) -> list[torch.Tensor | LowMemorySparseTensor]: @@ -77,7 +77,7 @@ def validate_model_with_teacher_similarity_metrics( target_hidden_states_per_batch: list[torch.Tensor], output_dir: str | Path, model_name: str, - extra_payload: Optional[dict[str, Any]] = None, + extra_payload: dict[str, Any] | None = None, pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader=None, diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/puzzletron/utils/checkpoint_manager.py similarity index 93% rename from modelopt/torch/_compress/utils/checkpoint_manager.py rename to modelopt/torch/puzzletron/utils/checkpoint_manager.py index b43c37481d..3fc4bf87e2 100644 --- a/modelopt/torch/_compress/utils/checkpoint_manager.py +++ b/modelopt/torch/puzzletron/utils/checkpoint_manager.py @@ -13,25 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Checkpoint manager for activation hook scoring with periodic saves and resume support. -""" +"""Checkpoint manager for activation hook scoring with periodic saves and resume support.""" import json import time from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.logger import aprint, mprint class ScoringCheckpointManager: """Manages checkpointing for activation hook scoring with periodic saves.""" def __init__(self, checkpoint_dir: str, activation_hooks=None, checkpoint_interval: int = 100): - """ - Initialize checkpoint manager. + """Initialize checkpoint manager. Args: checkpoint_dir: Directory to save checkpoints @@ -63,9 +60,8 @@ def __init__(self, checkpoint_dir: str, activation_hooks=None, checkpoint_interv if self.is_main_process: self.checkpoint_dir.mkdir(parents=True, exist_ok=True) - def load_checkpoint(self) -> Optional[Dict[str, Any]]: - """ - Load existing checkpoint if available, including hook states. + def load_checkpoint(self) -> dict[str, Any] | None: + """Load existing checkpoint if available, including hook states. Returns: Dict with checkpoint info or None if no checkpoint exists @@ -76,7 +72,7 @@ def load_checkpoint(self) -> Optional[Dict[str, Any]]: return None try: - with open(self.progress_file, "r") as f: + with open(self.progress_file) as f: checkpoint_data = json.load(f) # Validate checkpoint @@ -114,8 +110,7 @@ def load_checkpoint(self) -> Optional[Dict[str, Any]]: return None def load_hook_states(self, activation_hooks) -> bool: - """ - Load hook states from checkpoint files. + """Load hook states from checkpoint files. Args: activation_hooks: Hook objects to load states into @@ -173,8 +168,7 @@ def should_skip_batch(self, batch_idx: int) -> bool: return should_skip def update_progress(self, batch_idx: int, total_batches: int): - """ - Update progress and potentially save checkpoint. + """Update progress and potentially save checkpoint. Args: batch_idx: Current batch index @@ -207,8 +201,7 @@ def update_progress(self, batch_idx: int, total_batches: int): dist.barrier() def save_checkpoint(self): - """ - Save current checkpoint to disk (progress info only). + """Save current checkpoint to disk (progress info only). Hook states are saved separately in update_progress. """ try: diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py similarity index 97% rename from modelopt/torch/_compress/utils/data/dataloaders.py rename to modelopt/torch/puzzletron/utils/data/dataloaders.py index 865ad89fbc..892d1f3c2c 100644 --- a/modelopt/torch/_compress/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -DataLoader utilities for language model training and validation. -""" +"""DataLoader utilities for language model training and validation.""" from collections.abc import Callable, Mapping, Sequence from functools import partial @@ -30,8 +28,8 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.utils.data.dataset import ConstantLengthDataset +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.data.dataset import ConstantLengthDataset def collate_none_fn( diff --git a/modelopt/torch/_compress/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py similarity index 96% rename from modelopt/torch/_compress/utils/data/dataset.py rename to modelopt/torch/puzzletron/utils/data/dataset.py index 342b0821ef..a71049105e 100644 --- a/modelopt/torch/_compress/utils/data/dataset.py +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -14,14 +14,12 @@ # limitations under the License. # mypy: ignore-errors import functools -from typing import Optional, Sequence +from collections.abc import Sequence import numpy as np import torch from torch.utils.data import IterableDataset -from modelopt.torch._compress.tools.logger import aprint, mprint - FIM_TOKEN_START = " int: """Calculate the key-value dimension for grouped-query attention. TODO: Consider a better place for this function. + Args: n_heads_in_group: Number of attention heads per key-value group. n_head: Total number of attention heads. @@ -52,6 +53,7 @@ def raise_unknown_subblock_config_error(subblock_config: Any) -> None: """Raise an error for invalid subblock configuration types. TODO: Consider a better place for this function. + Args: subblock_config: The invalid subblock configuration object. @@ -67,6 +69,7 @@ def sizeof_dtype(dtype: torch.dtype) -> int | float: """Return the size in bytes of the given data type. TODO: Consider a better place for this function. + Args: dtype: PyTorch data type or custom type string (e.g., 'nvfp4'). @@ -122,10 +125,10 @@ def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str: def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: - """ - Convert a BlockConfig to a human-readable string representation. + """Convert a BlockConfig to a human-readable string representation. TODO: Consider a better place for this function. + Args: block_config: BlockConfig dataclass or dict containing attention and ffn configs. @@ -150,6 +153,7 @@ def subblock_config_to_str( """Convert a subblock config (FFN, Attention, Mamba, or MoE) to string. TODO: Consider a better place for this function. + Args: subblock_config: FFNConfig, AttentionConfig dataclass or dict. subblock_name: Name of subblock ('ffn', 'attention', 'mamba', 'moe'). @@ -212,8 +216,7 @@ def subblock_config_to_str( class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): def __init__(self, device=None, dtype=None): - """ - Create tensors with given device and dtype and don't run initialization + """Create tensors with given device and dtype and don't run initialization (but instead use "empty tensors", i.e. uninitialized memory). device: `torch.device` to work with @@ -222,8 +225,8 @@ def __init__(self, device=None, dtype=None): Example:: with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): model = LLaMA(model_config) - model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))""" - + model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth")) + """ self.device = device self.dtype = dtype diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py similarity index 92% rename from modelopt/torch/_compress/utils/validate_runtime_pipeline.py rename to modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py index b3be70644b..db1e8f2cea 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. +"""Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. Coordinates forward passes and loss computation through model shards distributed across GPUs using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. @@ -29,11 +28,11 @@ from tqdm import tqdm import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMForCausalLM, LMHead, ) -from modelopt.torch._compress.sewing_kit import ( +from modelopt.torch.puzzletron.sewing_kit import ( ExternalTarget, InputArgs, ModuleTarget, @@ -41,15 +40,15 @@ RemoteTarget, StitchedModule, ) -from modelopt.torch._compress.sewing_kit.core import InputReducer -from modelopt.torch._compress.sewing_kit.utils import ( +from modelopt.torch.puzzletron.sewing_kit.core import InputReducer +from modelopt.torch.puzzletron.sewing_kit.utils import ( distributed_recv_obj, distributed_send_obj, fake_tensor, ) -from modelopt.torch._compress.tools.checkpoint_utils import init_module_with_state_dict -from modelopt.torch._compress.tools.sharded_checkpoint_utils import DummyBlock -from modelopt.torch._compress.utils.validation import _organize_outputs, calculate_batch_outputs +from modelopt.torch.puzzletron.tools.checkpoint_utils import init_module_with_state_dict +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import DummyBlock +from modelopt.torch.puzzletron.utils.validation import _organize_outputs, calculate_batch_outputs class HiddenStatesAndLMHead(list): @@ -70,8 +69,7 @@ def calculate_losses_pipeline( checkpoint_manager=None, autocast_dtype: torch.dtype = torch.bfloat16, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: - """ - Do model forward on each batch and calculate LM loss. + """Do model forward on each batch and calculate LM loss. Optionally also calculate kl_div loss and other metrics from given target_hidden_states_per_batch. Optionally return hidden states per batch. Does not support data-parallel. diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/puzzletron/utils/validation.py similarity index 97% rename from modelopt/torch/_compress/utils/validation.py rename to modelopt/torch/puzzletron/utils/validation.py index d970105e68..0fff907549 100644 --- a/modelopt/torch/_compress/utils/validation.py +++ b/modelopt/torch/puzzletron/utils/validation.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Model validation and loss calculation utilities for single-GPU and multi-GPU setups. +"""Model validation and loss calculation utilities for single-GPU and multi-GPU setups. Also provides helper functions for loss metrics, KL divergence, JS divergence, and similarity losses for knowledge distillation. @@ -34,7 +33,7 @@ from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper from typing_extensions import Self -from modelopt.torch._compress.tools import kd_model +from modelopt.torch.puzzletron.tools import kd_model class UnshardedLowMemorySparseTensor: @@ -94,8 +93,7 @@ def calculate_losses( return_probs: bool = False, checkpoint_manager=None, ) -> tuple[dict[str, dict], None] | tuple[None, None]: - """ - Do model forward on each batch and calculate LM loss. + """Do model forward on each batch and calculate LM loss. Works on lit-llama models (single gpu) and huggingface models (can be multi gpu). Does not support data-parallel. @@ -313,8 +311,7 @@ def _calculate_teacher_similarity_scores( target_logits: torch.Tensor, calculate_full_score_ablations: bool, ) -> dict[str, list[float]]: - """ - hidden_states: [batch, tokens, n_embd] + """hidden_states: [batch, tokens, n_embd] target_hidden_states: [batch, tokens, n_embd] logits: [batch, tokens, vocab] target_logits: [batch, tokens, vocab] @@ -443,9 +440,7 @@ class ClipEpsilon(Enum): def _logits_to_logprobs( logits: torch.Tensor, clip_epsilon: ClipEpsilon, epsilon_factor: float ) -> torch.Tensor: - """ - logits: [tokens, vocab] - """ + """logits: [tokens, vocab]""" logprobs = logits.log_softmax( -1 ) # must normalize logits before clipping otherwise log(1/voacb) means nothing @@ -467,8 +462,7 @@ def kl_div( clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, epsilon_factor: float = 1.0, ) -> float: - """ - Kullback-Leibler Divergence for a single sample. + """Kullback-Leibler Divergence for a single sample. logits: [tokens, vocab] target_probs: [tokens, vocab] """ @@ -487,8 +481,7 @@ def js_div( clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, epsilon_factor: float = 1.0, ) -> float: - """ - Jensen-Shannon Divergence for a single sample. + """Jensen-Shannon Divergence for a single sample. logits: [tokens, vocab] target_probs: [tokens, vocab] """ @@ -508,8 +501,7 @@ def tv_dist( clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, epsilon_factor: float = 1.0, ) -> float: - """ - Total Variation Distance (L1-loss) for a single sample. + """Total Variation Distance (L1-loss) for a single sample. logits: [tokens, vocab] target_probs: [tokens, vocab] """ diff --git a/pyproject.toml b/pyproject.toml index 010070e633..0ae6bdf78b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ extend-ignore = [ "D", "E501", ] # Ignore missing docstrings or line length for Jupyter notebooks -"modelopt/torch/_compress/*" = [ +"modelopt/torch/puzzletron/*" = [ "C4", "D", "E", diff --git a/setup.py b/setup.py index bd14878a51..6096e31cab 100644 --- a/setup.py +++ b/setup.py @@ -102,8 +102,8 @@ "setuptools>=80", "setuptools-scm>=8", ], - # Dependedencies for modelopt.torch._compress subpackage - "compress": [ + # Dependedencies for modelopt.torch.puzzletron subpackage + "puzzletron": [ "fire", "hydra-core==1.3.2", "immutabledict", diff --git a/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-attn-pruning.yaml similarity index 100% rename from tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-attn-pruning.yaml diff --git a/tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-ffn-pruning.yaml similarity index 100% rename from tests/gpu/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-ffn-pruning.yaml diff --git a/tests/gpu/torch/_compress/resources/configs/pruning/attn_pruning.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml similarity index 100% rename from tests/gpu/torch/_compress/resources/configs/pruning/attn_pruning.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml diff --git a/tests/gpu/torch/_compress/resources/configs/pruning/ffn_pruning.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/pruning/ffn_pruning.yaml similarity index 100% rename from tests/gpu/torch/_compress/resources/configs/pruning/ffn_pruning.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/pruning/ffn_pruning.yaml diff --git a/tests/gpu/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml similarity index 100% rename from tests/gpu/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml diff --git a/tests/gpu/torch/_compress/resources/configs/pruning/pruning_defaults.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml similarity index 100% rename from tests/gpu/torch/_compress/resources/configs/pruning/pruning_defaults.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml diff --git a/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/validate_model_defaults.yaml similarity index 76% rename from tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/validate_model_defaults.yaml index 192b82c75e..1d042d75df 100644 --- a/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml +++ b/tests/_test_utils/torch/puzzletron/resources/configs/validate_model_defaults.yaml @@ -14,4 +14,4 @@ write_results: false calc_losses_on_cpu: false activations_log_dir: model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch._compress.utils.data.dataloaders.load_from_disk_fn} +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/gpu/torch/_compress/resources/configs/validate_solutions_defaults.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml similarity index 100% rename from tests/gpu/torch/_compress/resources/configs/validate_solutions_defaults.yaml rename to tests/_test_utils/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml diff --git a/tests/gpu/torch/_compress/resources/tokenizer/special_tokens_map.json b/tests/_test_utils/torch/puzzletron/resources/tokenizer/special_tokens_map.json similarity index 100% rename from tests/gpu/torch/_compress/resources/tokenizer/special_tokens_map.json rename to tests/_test_utils/torch/puzzletron/resources/tokenizer/special_tokens_map.json diff --git a/tests/gpu/torch/_compress/resources/tokenizer/tokenizer.json b/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer.json similarity index 100% rename from tests/gpu/torch/_compress/resources/tokenizer/tokenizer.json rename to tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer.json diff --git a/tests/gpu/torch/_compress/resources/tokenizer/tokenizer_config.json b/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer_config.json similarity index 100% rename from tests/gpu/torch/_compress/resources/tokenizer/tokenizer_config.json rename to tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer_config.json diff --git a/tests/gpu/torch/_compress/resources/tokenizer/truncate_tokenizer.py b/tests/_test_utils/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py similarity index 100% rename from tests/gpu/torch/_compress/resources/tokenizer/truncate_tokenizer.py rename to tests/_test_utils/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py diff --git a/tests/gpu/torch/_compress/compress_test_utils.py b/tests/_test_utils/torch/puzzletron/utils.py similarity index 96% rename from tests/gpu/torch/_compress/compress_test_utils.py rename to tests/_test_utils/torch/puzzletron/utils.py index 1da08602bf..6c9feecd0d 100644 --- a/tests/gpu/torch/_compress/compress_test_utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -22,14 +22,14 @@ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers def setup_test_model_and_data( project_root_path: Path, tmp_path: Path, rank: int ) -> tuple[Path, Path, Path]: """ - Setup the test model and data for the compress NAS search. + Setup the test model and data for the puzzletron NAS search. Args: project_root_path (Path): the root path of the project @@ -111,7 +111,7 @@ def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: """ Create a tokenizer for the Llama model. """ - tokenizer_path = project_root_path / "tests/gpu/torch/_compress/resources/tokenizer" + tokenizer_path = project_root_path / "tests/_test_utils/torch/puzzletron/resources/tokenizer" tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) return tokenizer diff --git a/tests/gpu/torch/_compress/conftest.py b/tests/gpu/torch/puzzletron/conftest.py similarity index 100% rename from tests/gpu/torch/_compress/conftest.py rename to tests/gpu/torch/puzzletron/conftest.py diff --git a/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py similarity index 90% rename from tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py rename to tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py index 7576f270b3..4b1ea0b414 100644 --- a/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py +++ b/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py @@ -16,12 +16,9 @@ import json from pathlib import Path -from gpu.torch._compress.compress_test_utils import ( - create_and_save_small_llama_model, - create_tokenizer, -) +from _test_utils.torch.puzzletron.utils import create_and_save_small_llama_model, create_tokenizer -from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( +from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) diff --git a/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py similarity index 86% rename from tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py rename to tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index 913bc2116c..c409da28be 100644 --- a/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -20,11 +20,11 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from gpu.torch._compress.compress_test_utils import setup_test_model_and_data +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.nas as mtn import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): @@ -43,18 +43,18 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" # # Run the mnt.convert() step # - input_model = CompressModel() + input_model = PuzzletronModel() mtn.convert( input_model, mode=[ ( - "compress", + "puzzletron", { "puzzle_dir": str(puzzle_dir), "input_model_path": str(llama_checkpoint_path), @@ -82,8 +82,6 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.cleanup() - print("PYTEST SUMMARY: test_nas_convert_ffn_pruning() test has finished successfully") - def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( @@ -101,18 +99,18 @@ def _test_nas_convert_attn_pruning_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" hydra_config_name = "Llama-3_1-8B-attn-pruning" # # Run the mnt.convert() step # - input_model = CompressModel() + input_model = PuzzletronModel() mtn.convert( input_model, mode=[ ( - "compress", + "puzzletron", { "puzzle_dir": str(puzzle_dir), "input_model_path": str(llama_checkpoint_path), @@ -142,5 +140,3 @@ def _test_nas_convert_attn_pruning_multiprocess_job( assert (puzzle_dir / "ckpts/n_heads_in_group32").exists() dist.cleanup() - - print("PYTEST SUMMARY: test_nas_convert_attn_pruning() test has finished successfully") diff --git a/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py similarity index 89% rename from tests/gpu/torch/_compress/nas/plugins/test_nas_search.py rename to tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index 1b4ed93c66..a1258c1d0b 100644 --- a/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -19,11 +19,11 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from gpu.torch._compress.compress_test_utils import setup_test_model_and_data +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.nas as mtn import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel def test_nas_search(project_root_path: Path, tmp_path: Path): @@ -42,18 +42,18 @@ def _test_nas_search_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" # # Run the mnt.convert() step # - input_model = CompressModel() + input_model = PuzzletronModel() converted_model = mtn.convert( input_model, mode=[ ( - "compress", + "puzzletron", { "puzzle_dir": str(puzzle_dir), "input_model_path": str(llama_checkpoint_path), @@ -100,5 +100,3 @@ def _test_nas_search_multiprocess_job( assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists() dist.cleanup() - - print("PYTEST SUMMARY: test_nas_search() test has finished successfully") diff --git a/tests/gpu/torch/_compress/test_compress.py b/tests/gpu/torch/puzzletron/test_puzzletron.py similarity index 83% rename from tests/gpu/torch/_compress/test_compress.py rename to tests/gpu/torch/puzzletron/test_puzzletron.py index dd6e0eb5a3..faf72f7495 100644 --- a/tests/gpu/torch/_compress/test_compress.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -19,11 +19,11 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from gpu.torch._compress.compress_test_utils import setup_test_model_and_data +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.utils.distributed as dist -from modelopt.torch._compress import compress -from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( +from modelopt.torch.puzzletron import puzzletron +from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) @@ -33,21 +33,23 @@ # Note: Bypass is disabled now in the test. -def test_compress(project_root_path: Path, tmp_path: Path): +def test_puzzletron(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=min(torch.cuda.device_count(), 2), # assertions configured for atmost 2 GPUs - job=partial(_test_compress_multiprocess_job, project_root_path, tmp_path), + job=partial(_test_puzzletron_multiprocess_job, project_root_path, tmp_path), backend="nccl", ) -def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): +def _test_puzzletron_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank ) - hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs" + hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" # Convert the Llama model to DeciLM model. @@ -59,7 +61,9 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran dist.barrier() # Compress the model using a one-click approach - compress.compress(str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path)) + puzzletron.puzzletron( + str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path) + ) # # Check assertions @@ -93,11 +97,6 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran dist.cleanup() - print( - "PYTEST SUMMARY: test_compress_model() test has finished successfully. Puzzle directory: ", - puzzle_dir, - ) - def _assert_score_pruning_activations(puzzle_dir: Path): """Assertions for the score_pruning_activations step 1.""" From 4c30bd5d73efb2d92c1f33c6d9fbeea41fa7d165 Mon Sep 17 00:00:00 2001 From: Liana Mikaelyan <45925959+LianaMikael@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:37:06 +0000 Subject: [PATCH 31/62] Add NeMo Conversion Scripts to Puzzletron (#784) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? This PR adds the scripts to convert the Puzzle model from HuggingFace to NeMo and back. This is required for running knowledge distillation. The tutorial is further updated with the instructions. ## Summary by CodeRabbit ## Release Notes * **New Features** * Added support for Puzzletron Llama-Nemotron models within the NeMo framework. * New model conversion tools enabling seamless switching between HuggingFace and NeMo formats. * Knowledge distillation workflow to restore model quality after compression. * **Documentation** * Added knowledge distillation process guide with step-by-step instructions. * **Dependencies** * Added lm-eval as an optional dependency. ✏️ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: Liana Mikaelyan --- examples/puzzletron/README.md | 19 + .../nemo_export/convert_hf_to_nemo.py | 98 ++ .../nemo_export/convert_nemo_to_hf.py | 96 ++ examples/puzzletron/requirements.txt | 1 + .../puzzletron/export/MCore/llama_nemotron.py | 1015 +++++++++++++++++ .../export/MCore/llama_nemotron_utils.py | 729 ++++++++++++ .../MCore/puzzletron_hf_config_utils.py | 142 +++ .../export/MCore/puzzletron_layer_specs.py | 928 +++++++++++++++ 8 files changed, 3028 insertions(+) create mode 100644 examples/puzzletron/nemo_export/convert_hf_to_nemo.py create mode 100644 examples/puzzletron/nemo_export/convert_nemo_to_hf.py create mode 100644 examples/puzzletron/requirements.txt create mode 100644 modelopt/torch/puzzletron/export/MCore/llama_nemotron.py create mode 100644 modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py create mode 100644 modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py create mode 100644 modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index e3a909d224..f16162083d 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -17,6 +17,7 @@ In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/ ```bash pip install -e .[hf,puzzletron] +pip install -r requirements.txt ``` - For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU. @@ -231,6 +232,24 @@ vllm bench latency --model path/to/model --load-format safetensors --trust-remot vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --load-format safetensors --trust-remote-code ``` +## Knowledge Distillation + +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. For this, we will use [NeMo framework](https://github.com/NVIDIA-NeMo/NeMo) with the [nemo:25.07](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.07) container. + +First, convert the HF model to NeMo format: + +```bash +python -m nemo_export/convert_hf_to_nemo --input-ckpt-path path/to/HF-model --output-ckpt-path path/to/save/model-nemo +``` + +Now you can utilize all the training features available in NeMo, including distillation. Please refer to the [NeMo distillation documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html). + +[Optional] Once distillation is complete, you can convert the distilled model back to the HuggingFace format. + +```bash +python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --output-ckpt-path path/to/save/model-HF +``` + ## Advanced Usage Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios. diff --git a/examples/puzzletron/nemo_export/convert_hf_to_nemo.py b/examples/puzzletron/nemo_export/convert_hf_to_nemo.py new file mode 100644 index 0000000000..0cf16b4486 --- /dev/null +++ b/examples/puzzletron/nemo_export/convert_hf_to_nemo.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from pathlib import Path +from typing import Any + +from nemo.collections import llm + +from modelopt.torch.puzzletron.export.MCore.llama_nemotron import ( + PuzzletronLlamaNemotronModel, + PuzzletronNemotronModelConfig, +) + + +def convert_model( + hf_model_path_local: str, output_path_nemo_local: str, overwrite: bool = False +) -> Any: + """Convert a Puzzletron HuggingFace model to NeMo format. + + Args: + hf_model_path_local: Path to the input Puzzletron HuggingFace model directory + output_path_nemo_local: Path where the converted Puzzletron NeMo model will be saved + overwrite: Whether to overwrite existing output directory + """ + + model = PuzzletronLlamaNemotronModel(config=PuzzletronNemotronModelConfig) + # NOTE: API call to import_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/294ddff187f68c055d87ffe9400e65975b38693d/nemo/collections/llm/api.py#L888 + print( + f"calling import_ckpt with model: {model}, " + f"source: {hf_model_path_local}, " + f"output_path: {output_path_nemo_local}, " + f"overwrite: {overwrite}" + ) + nemo2_path = llm.import_ckpt( + model=model, + source="hf://" + hf_model_path_local, + output_path=Path(output_path_nemo_local), + overwrite=overwrite, + ) + + print(f"Model saved to {nemo2_path}") + return nemo2_path + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert Puzzletron HuggingFace model to NeMo format" + ) + parser.add_argument( + "--input-ckpt-path", + "-i", + type=str, + required=True, + help="Path to the input Puzzletron HuggingFace model directory", + ) + parser.add_argument( + "--output-ckpt-path", + "-o", + type=str, + required=True, + help="Path where the converted Puzzletron NeMo model will be saved", + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Whether to overwrite existing output directory (default: False)", + ) + + args = parser.parse_args() + + # Validate input path + if not os.path.exists(args.input_ckpt_path): + raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}") + + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True) + + print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}") + convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite) + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/nemo_export/convert_nemo_to_hf.py b/examples/puzzletron/nemo_export/convert_nemo_to_hf.py new file mode 100644 index 0000000000..4645ae5b43 --- /dev/null +++ b/examples/puzzletron/nemo_export/convert_nemo_to_hf.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from pathlib import Path +from typing import Any + +from nemo.collections import llm + +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import copy_deci_lm_hf_code + + +def convert_model( + nemo_model_path_local: str, output_path_hf_local: str, overwrite: bool = False +) -> Any: + """Convert a NeMo model to HuggingFace format. + + Args: + nemo_model_path_local: Path to the input NeMo model file (.nemo) + output_path_hf_local: Path where the converted HuggingFace model will be saved + overwrite: Whether to overwrite existing output directory + """ + + # NOTE: API call to export_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/llm/api.py#L987 + print( + f"calling export_ckpt with path: {nemo_model_path_local}, " + f"target: hf, output_path: {output_path_hf_local}, " + f"target_model_name: PuzzletronLlamaNemotronModel, " + f"overwrite: {overwrite}" + ) + + hf_path = llm.export_ckpt( + path=nemo_model_path_local, + target="hf", + output_path=Path(output_path_hf_local), + target_model_name="PuzzletronLlamaNemotronModel", + overwrite=overwrite, + ) + + copy_deci_lm_hf_code(hf_path) + + print(f"Model saved to {hf_path}") + return hf_path + + +def main() -> None: + parser = argparse.ArgumentParser(description="Convert NeMo model to HuggingFace format") + parser.add_argument( + "--input-ckpt-path", + "-i", + type=str, + required=True, + help="Path to the input NeMo model checkpoint", + ) + parser.add_argument( + "--output-ckpt-path", + "-o", + type=str, + required=True, + help="Path where the converted Puzzletron HuggingFace model will be saved", + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Whether to overwrite existing output directory (default: False)", + ) + + args = parser.parse_args() + + # Validate input path + if not os.path.exists(args.input_ckpt_path): + raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}") + + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True) + + print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}") + convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite) + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/requirements.txt b/examples/puzzletron/requirements.txt new file mode 100644 index 0000000000..fe63c413bc --- /dev/null +++ b/examples/puzzletron/requirements.txt @@ -0,0 +1 @@ +lm-eval==0.4.9 diff --git a/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py b/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py new file mode 100644 index 0000000000..d4292322f7 --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py @@ -0,0 +1,1015 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# based on https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/llm/gpt/model/llama_nemotron.py + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, Optional, Union + +import torch +import torch.nn.functional as F +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel, torch_dtype_from_mcore_config +from nemo.collections.llm.gpt.model.llama import ( + Llama3Config, + Llama31Config, + Llama31Config70B, + LlamaConfig, + apply_rope_scaling, +) +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown +from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME +from nemo.lightning.io.pl import ckpt_to_weights_subdir +from nemo.lightning.io.state import TransformFns +from nemo.lightning.pytorch.utils import dtype_from_hf, dtype_from_str +from nemo.utils import logging +from nemo.utils.import_utils import safe_import +from torch import nn + +from modelopt.torch.puzzletron.tools.logger import mprint + +# from nemo.collections.llm.gpt.model.llama_nemotron import Llama33NemotronSuper49BConfig + + +_, HAVE_TE = safe_import("transformer_engine") +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.transformer.heterogeneous.heterogeneous_config import ( + HeterogeneousTransformerConfig, +) +from megatron.core.transformer.spec_utils import ModuleSpec + +if TYPE_CHECKING: + from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + from peft import AutoPeftModelForCausalLM, PeftConfig + from transformers import GenerationConfig, LlamaForCausalLM + from transformers import LlamaConfig as HFLlamaConfig + + from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig + +from modelopt.torch.puzzletron.export.MCore.llama_nemotron_utils import ( + _build_puzzletron_mappings_and_transforms, + _config_to_dict, + convert_attention_config_from_cfg_object, + convert_mlp_config_from_cfg_object, + convert_nemo_config_to_hf_decilm_config, + dtype_from_dict, + merge_qkv_for_puzzletron, + split_qkv_for_puzzletron, +) +from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( + PuzzletronHeterogeneousTransformerConfig, + get_gpt_heterogeneous_layer_spec_puzzletron, +) + + +def heterogeneous_layer_spec_puzzletron( + config: PuzzletronHeterogeneousTransformerConfig, +) -> ModuleSpec: + return get_gpt_heterogeneous_layer_spec_puzzletron(config, use_transformer_engine=HAVE_TE) + + +# Refactored to inherit directly from GPTConfig instead of Llama31Config70B +# This makes it easier to understand what attributes are set through the hierarchy +@dataclass +class PuzzletronNemotronModelConfig(GPTConfig, PuzzletronHeterogeneousTransformerConfig): + """Configuration for Puzzletron Nemotron models. + + DESIGN RATIONALE: + ================ + Refactored from original inheritance (Llama31Config70B + PuzzletronHeterogeneousTransformerConfig) + to explicit attribute definition for clarity and maintainability. Maintains identical behavior + to the original Llama hierarchy while enabling future flexibility. + + ATTRIBUTE ORGANIZATION: + ====================== + Explicitly defines attributes from the Llama hierarchy: + Llama31Config70B → Llama31Config → Llama3Config → LlamaConfig → GPTConfig + + FUTURE DEVELOPMENT: + ================== + Attributes can be freely modified/removed for future Puzzletron models. + In this case the tests in test_puzzletron_nemotron_config_inheritance.py will need to be updated. + Current explicit definition is for clarity during transition period. + """ + + # Override attributes from PuzzletronHeterogeneousTransformerConfig with Llama hierarchy values + # These ensure we maintain the same behavior as the original Llama31Config70B inheritance + + # ===== LlamaConfig attributes ===== + # Core model architecture + # NOTE: Default is F.silu, but this is overridden during instantiation to match all blocks + # See instantiate_nemo_config_from_adapted_dict() which enforces same activation across blocks + activation_func: Callable = F.silu + normalization: str = "RMSNorm" + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + # seq_length: int = 4096 # (will be overridden by Llama31Config70B) + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + share_embeddings_and_output_weights: bool = False + # Fusion settings + bias_activation_fusion: bool = True + masked_softmax_fusion: bool = True + persist_layer_norm: bool = True + bias_dropout_fusion: bool = True + apply_rope_fusion: bool = True + use_transformer_engine_op_fuser: Optional[bool] = None + + # ===== Llama3Config attributes ===== + num_query_groups: int = 8 + # init_method_std: float = 0.01 # (will be overridden by Llama31Config) + layernorm_epsilon: float = 1.0e-05 + rotary_percent: float = 1.0 + + # ===== Llama31Config attributes ===== + scale_factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + old_context_len: int = 8192 + init_method_std: float = 0.02 # (overrides Llama3Config) + + # ===== Llama31Config70B attributes ===== + # Core model architecture (70B-specific) + rotary_base: int = 500_000 + seq_length: int = 131072 # (overrides LlamaConfig) + num_layers: int = 80 # + hidden_size: int = 8192 # + ffn_hidden_size: int = 28672 # + num_attention_heads: int = 64 # + kv_channels: int = 128 # (derived from hidden_size // num_attention_heads) + make_vocab_size_divisible_by: int = 128 # + + # ===== PuzzletronHeterogeneousTransformerConfig attributes ===== + # Actual new PuzzleNemotronModelConfig attributes + heterogeneous_layers_config_path: Optional[str] = None + heterogeneous_layers_config_encoded_json: Optional[str] = None + transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = ( + heterogeneous_layer_spec_puzzletron + ) + + # HF-specific metadata for lossless round-trip conversion (HF → NeMo → HF) + # Stores HF config fields that don't have direct NeMo equivalents + source_hf_config_metadata: Optional[Dict[str, Any]] = None + + # NOTE: How activation_func is handled for Puzzletron models + # ============================================================== + # Puzzletron models can define activation functions per-block, but MCore's validation + # only checks the global activation_func (not per-block activations). + # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_config.py#L1043-L1061 + # + # Current approach (enforced in instantiate_nemo_config_from_adapted_dict): + # - All blocks must use the SAME activation function (None allowed for no-op blocks) + # - The global activation_func is set to match the blocks' shared activation + # - This ensures MCore's global validation passes correctly + # + # Rationale: + # 1. MCore validates global activation_func during __post_init__() (lines 1043-1061) + # 2. NeMo calls __post_init__() AGAIN during trainer.strategy.connect(model) + # See: https://github.com/NVIDIA/NeMo/blob/2e19aebd8c8fa9ff7ce9b5076ce130404713443c/nemo/lightning/_strategy_lib.py#L172-L175 + # 3. At runtime, MCore uses per-block activations from get_config_for_layer() + # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_block.py#L308-L319 + # + # For heterogeneous activations across blocks, MCore would need to update their + # validation logic to support per-block validation (e.g., in get_config_for_layer() or MLP.__init__) + + # ===== Llama31Config method ===== + def configure_model( + self, tokenizer, pre_process=None, post_process=None, vp_stage=None + ) -> "MCoreGPTModel": + """Configure and instantiate a Megatron Core Llama 3.1 model. + + NOTE: This method is originally from Llama31Config and is explicitly included here + for consistency and clarity. It maintains the same behavior as the original + Llama hierarchy inheritance approach. + + Extends the base configuration with Llama 3.1 specific RoPE scaling. + This method applies RoPE scaling for extended context length support. + """ + model = super().configure_model(tokenizer, pre_process, post_process, vp_stage) + # Apply rope scaling for Llama3.1 model + model.rotary_pos_emb.inv_freq = apply_rope_scaling( + model.rotary_pos_emb.inv_freq, + factor=self.scale_factor, + low_freq_factor=self.low_freq_factor, + high_freq_factor=self.high_freq_factor, + old_context_len=self.old_context_len, + ) + return model + + @classmethod + def from_dict_with_preprocessing(cls, config_dict): + # Potentially adapt the config_dict before instantiation + instance = cls(**config_dict) + # Potentially adapt the config after instantiation + return instance + + # static method + @staticmethod + def create_adapted_config_dict_from_puzzletron_config(cfg): + # TODO: consider doing do this without conversion to dictionary in the future (instead have an adapted config object) + # Create an empty config object of the same class as cfg + adapted_cfg_dict = dict() + orig_cfg_dict = vars(cfg) + + # Extract first set of values from the original config + adapted_cfg_dict["head_dim"] = orig_cfg_dict["head_dim"] + adapted_cfg_dict["num_attention_heads"] = orig_cfg_dict["num_attention_heads"] + # Handle rope_scaling - can be None, missing, or a dict + adapted_cfg_dict["rope_scaling"] = orig_cfg_dict.get("rope_scaling") or {} + + block_conf = { + "block_configs": [ + { + "attention": convert_attention_config_from_cfg_object( + orig_cfg_dict["block_configs"][i].attention, + adapted_cfg_dict["num_attention_heads"], + adapted_cfg_dict["head_dim"], + ), + "mlp": { + **convert_mlp_config_from_cfg_object( + orig_cfg_dict["block_configs"][i].ffn, + ( + orig_cfg_dict["block_configs"][i].parallel_blocks + if hasattr(orig_cfg_dict["block_configs"][i], "parallel_blocks") + else None + ), + ), + # Store the per-block activation function as a string (for JSON serialization) + "hidden_act": ( + orig_cfg_dict["block_configs"][i].ffn.hidden_act + if not ( + orig_cfg_dict["block_configs"][i].ffn.no_op + or orig_cfg_dict["block_configs"][i].ffn.replace_with_linear + ) + else None + ), + }, + } + for i in range(len(orig_cfg_dict["block_configs"])) + ] + } + if orig_cfg_dict["o_proj_bias"] != orig_cfg_dict["attention_bias"]: + raise NotImplementedError("o_proj_bias is not fully supported") + if orig_cfg_dict["position_embedding_type"] not in ["rope", "yarn"]: + # this one is not supported by MCore + raise ValueError( + f"only rope and yarn are supported, got {orig_cfg_dict['position_embedding_type']}" + ) + + # Handle dtype (new format uses 'dtype', old format uses 'torch_dtype') + # Check 'dtype' first, then fall back to 'torch_dtype' + if "dtype" in orig_cfg_dict and orig_cfg_dict["dtype"] is not None: + mprint(f"DEBUG: dtype found in config: {orig_cfg_dict['dtype']}") + adapted_cfg_dict["torch_dtype"] = orig_cfg_dict["dtype"] + elif "torch_dtype" in orig_cfg_dict and orig_cfg_dict["torch_dtype"] is not None: + mprint(f"DEBUG: torch_dtype found in config: {orig_cfg_dict['torch_dtype']}") + adapted_cfg_dict["torch_dtype"] = orig_cfg_dict["torch_dtype"] + else: + mprint( + f"WARNING: neither dtype nor torch_dtype found in config (or both are None), setting to bfloat16" + ) + adapted_cfg_dict["torch_dtype"] = "bfloat16" + + # TODO: check how config keys such as position_embedding_type are handled (since they're not passed to the constructor) + adapted_cfg_dict["heterogeneous_layers_config_path"] = None + adapted_cfg_dict["block_configs"] = block_conf["block_configs"] + adapted_cfg_dict["heterogeneous_layers_config_encoded_json"] = json.dumps( + block_conf, ensure_ascii=False + ) + adapted_cfg_dict["transformer_layer_spec"] = heterogeneous_layer_spec_puzzletron + adapted_cfg_dict["vocab_size"] = orig_cfg_dict["vocab_size"] + adapted_cfg_dict["num_layers"] = len(orig_cfg_dict["block_configs"]) + adapted_cfg_dict["hidden_size"] = orig_cfg_dict["hidden_size"] + # adapted_cfg_dict['num_attention_heads'] = cfg["num_attention_heads"] + adapted_cfg_dict["kv_channels"] = adapted_cfg_dict["head_dim"] + adapted_cfg_dict["scale_factor"] = float( + adapted_cfg_dict["rope_scaling"].get("factor", 8.0) + ) + adapted_cfg_dict["rotary_base"] = int(orig_cfg_dict.get("rope_theta", 500_000)) + adapted_cfg_dict["seq_length"] = int(orig_cfg_dict.get("max_position_embeddings", 131072)) + adapted_cfg_dict["init_method_std"] = float(orig_cfg_dict.get("initializer_range", 0.02)) + adapted_cfg_dict["layernorm_epsilon"] = float(orig_cfg_dict.get("rms_norm_eps", 1e-5)) + adapted_cfg_dict["share_embeddings_and_output_weights"] = bool( + orig_cfg_dict.get("tie_word_embeddings", False) + ) + # adapted_cfg_dict["make_vocab_size_divisible_by"] = 128 + + # Preserve HF-specific config fields that don't have NeMo equivalents + # This enables lossless round-trip conversion HF → NeMo → HF + source_hf_config_metadata = {} + + # eos_token_id: HF can have multiple EOS tokens [128001, 128008, 128009] + # but NeMo tokenizer only supports single eos_id (uses the last one) + if "eos_token_id" in orig_cfg_dict: + source_hf_config_metadata["eos_token_id"] = orig_cfg_dict["eos_token_id"] + + # auto_map: HF-specific field for custom model class loading via trust_remote_code + # Not relevant to NeMo but needed for HF model.from_pretrained() to work + if "auto_map" in orig_cfg_dict: + source_hf_config_metadata["auto_map"] = orig_cfg_dict["auto_map"] + + # dtype: HF uses 'dtype' field, NeMo uses 'torch_dtype', preserve both + if "dtype" in orig_cfg_dict: + source_hf_config_metadata["dtype"] = orig_cfg_dict["dtype"] + + # Store as direct config attribute (will be serialized by NeMo automatically) + adapted_cfg_dict["source_hf_config_metadata"] = ( + source_hf_config_metadata if source_hf_config_metadata else None + ) + + return adapted_cfg_dict + + +class PuzzletronLlamaNemotronModel(GPTModel): + """Llama-Nemotron model implementation based on the GPT model architecture. + + This class provides a high-level interface for Llama-Nemotron models, + implementing the specific architecture and settings needed for Llama-Nemotron models. + """ + + def __init__( + self, + config: Annotated[ + Optional[PuzzletronNemotronModelConfig] | type[PuzzletronNemotronModelConfig], + Config[PuzzletronNemotronModelConfig], + ] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__( + config or PuzzletronNemotronModelConfig(), + optim=optim, + tokenizer=tokenizer, + model_transform=model_transform, + ) + + +def instantiate_nemo_config_from_adapted_dict( + adapted_cfg_dict: dict, + generation_config: Optional["GenerationConfig"] = None, +) -> PuzzletronNemotronModelConfig: + """ + Instantiate PuzzletronNemotronModelConfig from adapted config dict. + + This function is shared by the importer and tests to ensure consistency. + + Args: + adapted_cfg_dict: Dict created by create_adapted_config_dict_from_puzzletron_config + generation_config: Optional generation config to attach + + Returns: + PuzzletronNemotronModelConfig instance + """ + + # Helper function for vocab size divisibility + def make_vocab_size_divisible_by(vocab_size: int) -> int: + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + # Keys used for PuzzletronNemotronModelConfig instantiation + INSTANTIATION_KEYS = { + "heterogeneous_layers_config_encoded_json", + "transformer_layer_spec", + "num_layers", + "hidden_size", + "num_attention_heads", + "kv_channels", + "scale_factor", + "init_method_std", + "layernorm_epsilon", + "seq_length", + "rotary_base", + "vocab_size", + "share_embeddings_and_output_weights", + "source_hf_config_metadata", + } + + # Keys that are metadata or derived (not directly passed to constructor) + metadata_keys = set(adapted_cfg_dict.keys()) - INSTANTIATION_KEYS + + mprint(f"DEBUG: Keys used for instantiation: {sorted(INSTANTIATION_KEYS)}") + mprint(f"DEBUG: Metadata keys (not used for direct instantiation): {sorted(metadata_keys)}") + for key in sorted(metadata_keys): + value = adapted_cfg_dict[key] + if isinstance(value, (list, dict)): + mprint(f" - {key}: {type(value).__name__} with {len(value)} items") + elif callable(value): + mprint(f" - {key}: {value.__name__ if hasattr(value, '__name__') else 'callable'}") + else: + mprint(f" - {key}: {value}") + + model_dtype = dtype_from_dict(adapted_cfg_dict) + + # Determine the unique activation_func from all blocks + # MCore validates the global activation_func, so we need to set it to match all blocks + heterogeneous_config = json.loads(adapted_cfg_dict["heterogeneous_layers_config_encoded_json"]) + block_list = heterogeneous_config.get("block_configs", []) + + # Assert that block_configs exists and is not empty + assert block_list, ( + "No block_configs found in heterogeneous_layers_config_encoded_json. " + "The JSON structure must contain a 'block_configs' list with at least one block." + ) + + activation_funcs = [] + + for i, block in enumerate(block_list): + # Extract hidden_act from MLP config (if present) + if "mlp" in block and "hidden_act" in block["mlp"]: + hidden_act_str = block["mlp"]["hidden_act"] + + # Track None/null values (used for no-op blocks) + if hidden_act_str is None: + activation_funcs.append(None) + continue + + # For now, only support silu and gelu activations + # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_config.py#L1043-L1048 + if hidden_act_str == "silu": + activation_funcs.append(F.silu) + elif hidden_act_str == "gelu": + activation_funcs.append(F.gelu) + else: + raise NotImplementedError( + f"Unsupported activation function: '{hidden_act_str}' in block {i}. " + f"Only 'silu', 'gelu', and None/null are currently supported. " + f"MCore's bias_activation_fusion only validates these activation functions." + ) + # If no hidden_act key or no MLP, we treat it as None + else: + activation_funcs.append(None) + + # Separate None and not-None activations + not_none_activations = [f for f in activation_funcs if f is not None] + + # Check that all not-None activation functions are the same + unique_not_none = {id(f) for f in not_none_activations} + + if len(unique_not_none) == 0: + # No activation functions found (all blocks are no-op or have None) + # Default to F.silu to pass MCore validation + global_activation_func = F.silu + mprint( + "WARNING: No not-None activation functions found in blocks, defaulting global activation_func to F.silu" + ) + elif len(unique_not_none) == 1: + # All not-None blocks use the same activation function (safe) + global_activation_func = not_none_activations[0] + func_name = ( + global_activation_func.__name__ + if hasattr(global_activation_func, "__name__") + else str(global_activation_func) + ) + none_count = activation_funcs.count(None) + total_count = len(activation_funcs) + mprint( + f"INFO: All {total_count - none_count} not-None blocks use the same activation function: {func_name} ({none_count} None/no-op blocks)" + ) + else: + # Multiple different not-None activation functions found (currently not supported/tested) + func_names = [f.__name__ if hasattr(f, "__name__") else "None" for f in activation_funcs] + unique_func_names = set(f.__name__ for f in not_none_activations) + assert False, ( + f"Puzzletron blocks must all use the same activation function (None allowed for no-op blocks). " + f"Found {len(unique_not_none)} different not-None activation functions across blocks: {unique_func_names}. " + f"Block activations: {func_names}. " + f"MCore's validation only checks the global activation_func, which would not match heterogeneous activations. " + f"Either make all blocks use the same activation, or update MCore to support per-block validation." + ) + + return PuzzletronNemotronModelConfig( + heterogeneous_layers_config_encoded_json=adapted_cfg_dict[ + "heterogeneous_layers_config_encoded_json" + ], + heterogeneous_layers_config_path=None, # We directly load the block config as json + transformer_layer_spec=adapted_cfg_dict["transformer_layer_spec"], + activation_func=global_activation_func, # Set to match all blocks + num_layers=adapted_cfg_dict["num_layers"], + hidden_size=adapted_cfg_dict["hidden_size"], + num_attention_heads=adapted_cfg_dict["num_attention_heads"], + kv_channels=adapted_cfg_dict["kv_channels"], + scale_factor=adapted_cfg_dict["scale_factor"], + init_method_std=adapted_cfg_dict["init_method_std"], + layernorm_epsilon=adapted_cfg_dict["layernorm_epsilon"], + seq_length=adapted_cfg_dict["seq_length"], + rotary_base=adapted_cfg_dict["rotary_base"], + make_vocab_size_divisible_by=make_vocab_size_divisible_by(adapted_cfg_dict["vocab_size"]), + vocab_size=adapted_cfg_dict["vocab_size"], + share_embeddings_and_output_weights=adapted_cfg_dict["share_embeddings_and_output_weights"], + # HF-specific metadata for lossless round-trip conversion + source_hf_config_metadata=adapted_cfg_dict.get("source_hf_config_metadata"), + fp16=(model_dtype == torch.float16), + bf16=(model_dtype == torch.bfloat16), + params_dtype=model_dtype, + generation_config=generation_config, + ) + + +@io.model_importer(PuzzletronLlamaNemotronModel, "hf") +class PuzzletronHFLlamaNemotronImporter( + io.ModelConnector["LlamaForCausalLM", PuzzletronLlamaNemotronModel] +): + """Importer for converting Hugging Face Llama-Nemotron models to NeMo format. + + This class handles the conversion of Hugging Face's LlamaForCausalLM models + to NeMo's PuzzletronLlamaNemotronModel format, including weight mapping and configuration translation. + """ + + # Base mapping using standard LLaMA weight names + # Layernorm wildcards are replaced with per-layer mappings in convert_state() + # TODO: MoE and Mamba layer conversions have not been tested yet + default_mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + def init(self) -> PuzzletronLlamaNemotronModel: + """Initialize a NeMo LlamaModel instance. + + Returns: + LlamaModel: Initialized NeMo Llama model with the appropriate configuration + and tokenizer. + """ + config = self.config + mprint(f"DEBUG: NeMo config dtype settings:") + mprint(f" - config.bf16: {config.bf16}") + mprint(f" - config.fp16: {config.fp16}") + mprint(f" - config.params_dtype: {config.params_dtype}") + return PuzzletronLlamaNemotronModel(config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + """Apply the conversion from HF to NeMo format. + + Args: + output_path: Path where the converted model will be saved + + Returns: + Path: Path to the saved NeMo model + """ + from transformers import AutoModelForCausalLM + + logging.info(f"Load Puzzletron HF model {str(self)}") + source = AutoModelForCausalLM.from_pretrained( + str(self), trust_remote_code=True, torch_dtype="auto" + ) + logging.info("Initialize NeMo Puzzletron Llama Nemotron model") + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + mprint( + f"Converted Llama-Nemotron model to Nemo, model saved to {output_path} in {source.dtype}." + ) + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source: Any, target: Any) -> Any: + """Convert state dict from HF format to NeMo format. + + Maps the weights from the HF model to the NeMo model according to + the appropriate mapping scheme. + + Args: + source: Source HF model + target: Target NeMo model + + Returns: + The result of applying the transforms + """ + mapping = self.default_mapping.copy() + + if target.config.normalization == "LayerNorm": + mapping["model.norm.bias"] = "decoder.final_layernorm.bias" + if getattr(source.config, "tie_word_embeddings", False): + del mapping["lm_head.weight"] + + # Puzzletron models must have block_configs for heterogeneous layer support + assert hasattr(source.config, "block_configs"), "Puzzletron models must have block_configs" + + # Build per-layer specific mappings for heterogeneous support + attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs = ( + _build_puzzletron_mappings_and_transforms(source.config) + ) + + # Remove layernorm wildcards from default_mapping - these will be replaced with + # specific per-layer mappings based on each layer's architecture. + for pattern in [ + "model.layers.*.input_layernorm.weight", + "model.layers.*.post_attention_layernorm.weight", + ]: + if pattern in mapping: + del mapping[pattern] + + # Add all layer-specific mappings + mapping.update(**attn_mapping) + mapping.update(**ffn_mapping) + mapping.update(**mamba_mapping) + mapping.update(**moe_mapping) + + # Create transforms from specification + transforms = [] + + # Helper to create merge_qkv closure with proper layer index capture + def make_merge_qkv_fn(layer_idx): + def merge_qkv_fn(ctx, q, k, v): + return merge_qkv_for_puzzletron(ctx, q, k, v, idx=layer_idx) + + return merge_qkv_fn + + for spec in transform_specs: + if spec["transform_function"] == "merge_qkv_for_puzzletron": + # Fixed: proper closure to avoid variable capture issues + layer_idx = spec["kwargs"]["idx"] + transforms.append( + io.state_transform( + source_key=spec["source_key"], + target_key=spec["target_key"], + fn=make_merge_qkv_fn(layer_idx), + ) + ) + elif spec["transform_function"] == "merge_fc1_for_moe": + transforms.append( + io.state_transform( + source_key=spec["source_key"], + target_key=spec["target_key"], + fn=TransformFns.merge_fc1, + ) + ) + + # Add standard FC1 merge transform + transforms.append( + io.state_transform( + source_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + target_key="decoder.layers.*.mlp.linear_fc1.weight", + fn=TransformFns.merge_fc1, + ) + ) + return io.apply_transforms(source, target, mapping=mapping, transforms=transforms) + + @property + def tokenizer(self) -> "AutoTokenizer": + """Get the tokenizer for the HF model. + + Returns: + AutoTokenizer: Tokenizer instance initialized from the HF model's tokenizer + """ + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)), trust_remote_code=True) + + @property + def config(self) -> PuzzletronNemotronModelConfig: + """Create a NeMo LlamaNemotronConfig from the HF model config. + + Translates the HF configuration parameters to the equivalent NeMo + configuration. + + Returns: + PuzzletronNemotronModelConfig: Puzzletron NeMo configuration for Llama models + """ + from transformers import AutoConfig, GenerationConfig + + source = AutoConfig.from_pretrained(str(self), trust_remote_code=True) + + # Validate that this is a proper Puzzletron-Nemotron checkpoint + assert getattr(source, "rope_scaling", None), ( + "Llama-Nemotron model should have rope scaling" + ) + assert getattr(source, "block_configs", None) is not None, ( + "Puzzletron-Nemotron model should be heterogeneous and have block configs" + ) + + adapted_cfg_dict = ( + PuzzletronNemotronModelConfig.create_adapted_config_dict_from_puzzletron_config(source) + ) + + try: + generation_config = GenerationConfig.from_pretrained(str(self)) + except Exception: + generation_config = None + + output = instantiate_nemo_config_from_adapted_dict( + adapted_cfg_dict, generation_config=generation_config + ) + return output + + +@io.model_exporter(PuzzletronLlamaNemotronModel, "hf") +class PuzzletronHFLlamaNemotronExporter( + io.ModelConnector[PuzzletronLlamaNemotronModel, "LlamaForCausalLM"] +): + """Exporter for converting NeMo Puzzletron Llama-Nemotron models to Hugging Face format. + + This class handles the conversion of NeMo's PuzzletronLlamaNemotronModel to Hugging Face's + LlamaForCausalLM format, including weight mapping and configuration translation. + It supports heterogeneous model architectures with Puzzletron-specific configurations. + + The exporter performs the following key operations: + 1. Initializes a Hugging Face model with appropriate configuration + 2. Maps weights from NeMo format to Hugging Face format + 3. Handles special cases for heterogeneous architectures with Mamba, MoE, and other custom layers + 4. Saves the converted model and tokenizer to the specified output path + + Attributes: + tokenizer: The tokenizer associated with the NeMo model + config: The configuration for the Hugging Face model + + Methods: + init: Initialize a Hugging Face model instance + apply: Convert and save the model to Hugging Face format + convert_state: Convert model weights from NeMo to Hugging Face format + """ + + # Base mapping for NeMo -> HF conversion (reversed from importer) + # Layernorm wildcards are replaced with per-layer mappings in convert_state() + default_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + @property + def config(self) -> "DeciLMConfig": + """Create a HF DeciLMConfig from the NeMo model config. + + This method constructs a DeciLMConfig for Puzzletron models by parsing the + heterogeneous_layers_config_encoded_json from the NeMo config and mapping + the fields to the HF DeciLM format. + + Returns: + DeciLMConfig: HF configuration for Puzzletron DeciLM models + """ + # Load the NeMo config + source_config = io.load_context(str(self), subpath="model.config") + + # Get preserved HF config metadata (stored as direct attribute) + # This enables lossless round-trip conversion HF → NeMo → HF + source_hf_config_metadata = getattr(source_config, "source_hf_config_metadata", None) or {} + + # Get EOS token ID(s) - prefer preserved value from source HF config metadata + # (HF supports multiple EOS tokens, NeMo tokenizer only has single eos_id) + eos_token_id = source_hf_config_metadata.get("eos_token_id", self.tokenizer.eos_id) + + # Use the shared conversion function + return convert_nemo_config_to_hf_decilm_config( + nemo_config=source_config, + vocab_size=self.tokenizer.vocab_size, + eos_token_id=eos_token_id, + bos_token_id=self.tokenizer.bos_id, + pad_token_id=getattr(self.tokenizer, "pad_id", None), + ) + + def init(self, dtype=torch.bfloat16, from_config=False, model_name=None) -> "LlamaForCausalLM": + """Initialize a Hugging Face LlamaForCausalLM model instance. + + This method creates a new Hugging Face model instance with the appropriate configuration + and data type. Puzzletron models always use from_config=True and create a DeciLMForCausalLM. + + Args: + dtype (torch.dtype, optional): Data type for model parameters. Defaults to torch.bfloat16. + from_config (bool, optional): Whether to initialize from config or load from pretrained. + For Puzzletron models, this should always be True. Defaults to False. + model_name (str, optional): Name of the pretrained model to load. Not used for Puzzletron + models since we generate the config dynamically. Defaults to None. + + Returns: + DeciLMForCausalLM: Initialized Hugging Face DeciLM model instance + + Raises: + ValueError: If model_name is provided (not supported for Puzzletron models) + """ + from transformers.modeling_utils import no_init_weights + + from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMForCausalLM, + ) + + with no_init_weights(): + if from_config: + # Puzzletron models: create DeciLMForCausalLM from self.config property + model = DeciLMForCausalLM(self.config) + model = model.to(dtype=dtype) + return model + else: + # Puzzletron models don't support loading from pretrained HF model cards + raise ValueError( + "Puzzletron models do not have official HF model cards. " + "Use from_config=True to create models from NeMo config." + ) + + def apply(self, output_path: Path, target_model_name=None) -> Path: + """Convert and save a NeMo Puzzletron Llama-Nemotron model to Hugging Face format. + + This method performs the complete conversion process: + 1. Loads the NeMo model checkpoint + 2. Creates the Hugging Face model from config + 3. Converts and transfers the weights + 4. Saves the converted model and tokenizer + + Args: + output_path (Path): Directory path where the converted model will be saved + target_model_name (str, optional): Not used for Puzzletron models. Kept for API compatibility. + + Returns: + Path: Path to the saved Hugging Face model directory + """ + logging.info("Loading Puzzletron Llama-Nemotron NeMo checkpoint..") + source, _ = self.nemo_load(str(self)) + + # Puzzletron models always use from_config=True to generate DeciLMConfig dynamically + target = self.init( + torch_dtype_from_mcore_config(source.config), + from_config=True, + model_name=None, + ) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.tokenizer.save_pretrained(output_path) + + # Copy custom Python files needed for Puzzletron models + from modelopt.torch.puzzletron.export.MCore.llama_nemotron_utils import ( + copy_puzzletron_python_files_to_decilm_checkpoint, + embed_chat_template_in_tokenizer_config, + ) + + copy_puzzletron_python_files_to_decilm_checkpoint(str(output_path)) + + # Fix tokenizer: embed chat_template from .jinja file into tokenizer_config.json + # NeMo's HF → NeMo import extracts chat_template to .jinja but doesn't preserve + # it in tokenizer_config.json. We restore it here for accuracy parity. + embed_chat_template_in_tokenizer_config(str(self), str(output_path)) + + return output_path + + def convert_state(self, source: Any, target: Any) -> Any: + """Convert state dict from NeMo format to HF format. + + Maps the weights from the NeMo model to the HF model according to + the appropriate mapping scheme for Puzzletron models. + + This method follows the same pattern as the importer but with reversed mappings: + 1. Start with default mapping + 2. Remove layernorm wildcards (will be replaced with per-layer mappings) + 3. Build per-layer specific mappings using helper function and reverse them + 4. Create transforms for weight conversions + + Args: + source: Source NeMo model + target: Target HF model + + Returns: + The target model with weights transferred from source + """ + mapping = self.default_mapping.copy() + + # Handle LayerNorm bias if present + if source.config.normalization == "LayerNorm": + mapping["decoder.final_layernorm.bias"] = "model.norm.bias" + + # Handle tied embeddings + if getattr(source.config, "share_embeddings_and_output_weights", False): + # Remove output_layer mapping if embeddings are tied + if "output_layer.weight" in mapping: + del mapping["output_layer.weight"] + + # Build per-layer specific mappings for heterogeneous support + attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs = ( + _build_puzzletron_mappings_and_transforms(source.config) + ) + + # Remove layernorm wildcards from default_mapping - these will be replaced with + # specific per-layer mappings based on each layer's architecture. + for pattern in [ + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + ]: + if pattern in mapping: + del mapping[pattern] + + # For exporter: reverse all mappings (HF -> NeMo becomes NeMo -> HF) + attn_mapping = {v: k for k, v in attn_mapping.items()} + ffn_mapping = {v: k for k, v in ffn_mapping.items()} + mamba_mapping = {v: k for k, v in mamba_mapping.items()} + moe_mapping = {v: k for k, v in moe_mapping.items()} + + # Add all layer-specific mappings + mapping.update(**attn_mapping) + mapping.update(**ffn_mapping) + mapping.update(**mamba_mapping) + mapping.update(**moe_mapping) + + # Create transforms from specifications (reversed for exporter) + transforms = [] + + # Helper to create split_qkv closure with proper layer index capture + def make_split_qkv_fn(layer_idx): + def split_qkv_fn(ctx, qkv): + return split_qkv_for_puzzletron(ctx, qkv, idx=layer_idx) + + return split_qkv_fn + + for spec in transform_specs: + if spec["transform_function"] == "merge_qkv_for_puzzletron": + # For exporter: split QKV (NeMo -> HF) + layer_idx = spec["kwargs"]["idx"] + transforms.append( + io.state_transform( + source_key=spec["target_key"], # NeMo key + target_key=spec["source_key"], # HF key + fn=make_split_qkv_fn(layer_idx), + ) + ) + elif spec["transform_function"] == "merge_fc1_for_moe": + # For exporter: split FC1 for MoE (NeMo -> HF) + transforms.append( + io.state_transform( + source_key=spec["target_key"], # NeMo key + target_key=spec["source_key"], # HF key + fn=TransformFns.split_fc1, + ) + ) + + # Add standard transforms for FC1 splitting and padding pruning + transforms.extend( + [ + io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + fn=TransformFns.split_fc1, + ), + io.state_transform( + source_key="embedding.word_embeddings.weight", + target_key="model.embed_tokens.weight", + fn=TransformFns.prune_padding, + ), + io.state_transform( + source_key="output_layer.weight", + target_key="lm_head.weight", + fn=TransformFns.prune_padding, + ), + ] + ) + + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=transforms, + ) + + @property + def tokenizer(self) -> "TokenizerSpec": + """Get the tokenizer from the NeMo model. + + Returns: + TokenizerSpec: Tokenizer from the NeMo model + """ + return io.load_context(str(self), subpath="model").tokenizer + + +__all__ = [ + "PuzzletronLlamaNemotronModel", +] diff --git a/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py b/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py new file mode 100644 index 0000000000..8d01ec9537 --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py @@ -0,0 +1,729 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from megatron.core.transformer.spec_utils import ModuleSpec +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import ( + AutoTokenizer as NemoAutoTokenizer, +) +from nemo.collections.llm.gpt.model.base import GPTModel +from nemo.collections.llm.gpt.model.llama_nemotron import ( + HFLlamaNemotronImporter, + PuzzletronNemotronModelConfig, +) +from nemo.lightning import io, teardown +from nemo.lightning.io.state import TransformFns +from nemo.lightning.pytorch.utils import dtype_from_str +from nemo.utils.import_utils import safe_import +from transformers import AutoModelForCausalLM, AutoTokenizer + +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( + PuzzletronAttentionConfig, + PuzzletronHeterogeneousTransformerConfig, + PuzzletronMLPConfig, + get_gpt_heterogeneous_layer_spec_puzzletron, +) + + +def convert_attention_config_from_cfg_object(attention_config, num_attention_heads, head_dim): + for unsupported_key in [ + "llama4", + "num_sink_tokens", + "sparsify", + "unshifted_sink", + "use_prefill_window_in_sink_attention", + ]: + if hasattr(attention_config, unsupported_key) and getattr( + attention_config, unsupported_key + ) not in [ + None, + False, + ]: + # + # if attention_config.get(unsupported_key, None) not in [None, False]: + raise NotImplementedError(f"{unsupported_key} is not supported") + window_size = attention_config.window_size if hasattr(attention_config, "window_size") else None + if window_size is not None: + window_size = (window_size, 0) + is_mamba = attention_config.mamba if hasattr(attention_config, "mamba") else False + n_heads_in_group = ( + attention_config.n_heads_in_group if hasattr(attention_config, "n_heads_in_group") else 1 + ) + if n_heads_in_group is None: + n_heads_in_group = 1 + return asdict( + PuzzletronAttentionConfig( + no_op=attention_config.no_op if hasattr(attention_config, "no_op") else False, + replace_with_linear=( + attention_config.replace_with_linear + if hasattr(attention_config, "replace_with_linear") + else False + ), + num_attention_heads=num_attention_heads, + num_query_groups=num_attention_heads // n_heads_in_group, + kv_channels=head_dim, + window_size=window_size, + multi_latent_attention=False, + is_mamba=is_mamba, + mamba_state_dim=( + attention_config.mamba.state_dim + if is_mamba and hasattr(attention_config.mamba, "state_dim") + else 128 + ), + mamba_head_dim=( + attention_config.mamba.head_dim + if is_mamba and hasattr(attention_config.mamba, "head_dim") + else 64 + ), + mamba_num_groups=( + attention_config.mamba.num_groups + if is_mamba and hasattr(attention_config.mamba, "num_groups") + else 8 + ), + mamba_num_heads=( + attention_config.mamba.num_heads + if is_mamba and hasattr(attention_config.mamba, "num_heads") + else None + ), + ) + ) + + +def convert_mlp_config_from_cfg_object(mlp_config, parallel_blocks): + """Convert MLP config from HF format to NeMo format. + + Args: + mlp_config: The MLP configuration object from HF + parallel_blocks: Parallel blocks configuration (not currently supported) + """ + if parallel_blocks is not None: + raise NotImplementedError("parallel_blocks is not supported") + if not hasattr(mlp_config, "gated") or mlp_config.gated is False: + raise NotImplementedError("notgated MLP is not supported") + + # Validate this block's activation function + if not hasattr(mlp_config, "hidden_act"): + raise ValueError(f"MLP config must have hidden_act attribute") + # if mlp_config.hidden_act != block_hidden_act: + # raise ValueError(f"MLP config hidden_act mismatch: config has {mlp_config.hidden_act}, expected {block_hidden_act}") + + if hasattr(mlp_config, "sparsify") and mlp_config.sparsify is not None: + raise NotImplementedError("sparsify is not supported") + is_moe = hasattr(mlp_config, "moe") and mlp_config.moe is not None + # Note: hidden_act is validated above but not stored in PuzzletronMLPConfig + # It will be used at the call site for the NeMo model config + return asdict( + PuzzletronMLPConfig( + no_op=mlp_config.no_op if hasattr(mlp_config, "no_op") else False, + replace_with_linear=mlp_config.replace_with_linear + if hasattr(mlp_config, "replace_with_linear") + else False, + ffn_hidden_size=mlp_config.intermediate_size + if hasattr(mlp_config, "intermediate_size") + else None, + num_moe_experts=( + mlp_config.moe.num_local_experts + if is_moe and hasattr(mlp_config.moe, "num_local_experts") + else None + ), + moe_shared_expert_intermediate_size=( + mlp_config.moe.shared_expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "shared_expert_intermediate_dim") + else None + ), + moe_ffn_hidden_size=( + mlp_config.moe.expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "expert_intermediate_dim") + else None + ), + moe_router_topk=( + mlp_config.moe.num_experts_per_tok + if is_moe and hasattr(mlp_config.moe, "num_experts_per_tok") + else 2 + ), + ) + ) + + +def convert_nemo_config_to_hf_decilm_config( + nemo_config: "PuzzletronNemotronModelConfig", + vocab_size: int, + eos_token_id: Union[int, List[int], None] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, +) -> "DeciLMConfig": + """Convert a NeMo PuzzletronNemotronModelConfig to HF DeciLMConfig. + + This function extracts the conversion logic from the exporter so it can be + used in unit tests without requiring file I/O. + + Args: + nemo_config: The NeMo config to convert + vocab_size: Vocabulary size for the HF config + eos_token_id: EOS token ID(s). Can be int or list of ints. + bos_token_id: BOS token ID + pad_token_id: PAD token ID + + Returns: + DeciLMConfig: The equivalent HF config + """ + + # Get preserved HF config metadata (stored as direct attribute) + # This enables lossless round-trip conversion HF → NeMo → HF + source_hf_config_metadata = getattr(nemo_config, "source_hf_config_metadata", None) or {} + + # Parse the heterogeneous layers config from JSON + block_configs = [] + + if ( + hasattr(nemo_config, "heterogeneous_layers_config_encoded_json") + and nemo_config.heterogeneous_layers_config_encoded_json + ): + try: + heterogeneous_config = json.loads(nemo_config.heterogeneous_layers_config_encoded_json) + raw_block_configs = heterogeneous_config.get("block_configs", []) + + for i, raw_block_config in enumerate(raw_block_configs): + attn_block = raw_block_config.get("attention", {}) + mlp_block = raw_block_config.get("mlp", {}) + + # Configure attention + attention_config = { + "no_op": attn_block.get("no_op", False), + "replace_with_linear": attn_block.get("replace_with_linear", False), + "sparsify": attn_block.get("sparsify", None), + "n_heads_in_group": attn_block.get( + "num_attention_heads", nemo_config.num_attention_heads + ) + // attn_block.get("num_query_groups", nemo_config.num_query_groups), + "window_length": attn_block.get("window_size", None), + "num_sink_tokens": attn_block.get("num_sink_tokens", None), + "use_prefill_window_in_sink_attention": attn_block.get( + "use_prefill_window_in_sink_attention", False + ), + "unshifted_sink": attn_block.get("unshifted_sink", False), + } + + # Handle Mamba: convert from NeMo flat structure to HF nested structure + if attn_block.get("is_mamba", False): + attention_config["mamba"] = { + "state_dim": attn_block.get("mamba_state_dim", 128), + "num_heads": attn_block.get( + "mamba_num_heads", nemo_config.num_attention_heads + ), + "head_dim": attn_block.get("mamba_head_dim", 64), + "num_groups": attn_block.get("mamba_num_groups", 8), + } + else: + attention_config["mamba"] = None + + # Handle Llama4: pass through as dict if present + attention_config["llama4"] = attn_block.get("llama4", None) + + # Configure FFN + ffn_config = { + "no_op": mlp_block.get("no_op", False), + "replace_with_linear": mlp_block.get("replace_with_linear", False), + "sparsify": mlp_block.get("sparsify", None), + "intermediate_size": mlp_block.get( + "ffn_hidden_size", nemo_config.ffn_hidden_size + ), + "gated": True, # Puzzletron uses gated activations + # Use the activation function name extracted from this block's config + "hidden_act": mlp_block.get("hidden_act", None), + } + + # Handle MoE: convert from NeMo flat structure to HF nested structure + num_moe_experts = mlp_block.get("num_moe_experts", None) + if num_moe_experts is not None: + ffn_config["moe"] = { + "num_local_experts": num_moe_experts, + "num_experts_per_tok": mlp_block.get("moe_router_topk", 1), + "expert_intermediate_dim": mlp_block.get("moe_ffn_hidden_size", 8192), + "shared_expert_intermediate_dim": mlp_block.get( + "moe_shared_expert_intermediate_size", 8192 + ), + } + else: + ffn_config["moe"] = None + + block_configs.append({"attention": attention_config, "ffn": ffn_config}) + except (json.JSONDecodeError, KeyError) as e: + raise ValueError(f"Could not parse heterogeneous config JSON: {e}") + else: + raise ValueError("No block configs found in source configuration") + + # Create rope scaling config + rope_scaling = { + "factor": nemo_config.scale_factor, + "low_freq_factor": getattr(nemo_config, "low_freq_factor", 1.0), + "high_freq_factor": getattr(nemo_config, "high_freq_factor", 4.0), + "original_max_position_embeddings": getattr(nemo_config, "old_context_len", 8192), + "rope_type": "llama3", + } + + # Get EOS token ID(s) - prefer preserved value from source HF config metadata or provided value + if eos_token_id is None: + eos_token_id = source_hf_config_metadata.get("eos_token_id", None) + + # Create DeciLM config + hf_config = DeciLMConfig( + block_configs=block_configs, + hidden_size=nemo_config.hidden_size, + max_position_embeddings=nemo_config.seq_length, + num_attention_heads=nemo_config.num_attention_heads, + num_hidden_layers=nemo_config.num_layers, + tie_word_embeddings=nemo_config.share_embeddings_and_output_weights, + vocab_size=vocab_size, + rms_norm_eps=nemo_config.layernorm_epsilon, + attention_bias=getattr(nemo_config, "attention_bias", False), + o_proj_bias=getattr( + nemo_config, "o_proj_bias", getattr(nemo_config, "attention_bias", False) + ), + rope_theta=nemo_config.rotary_base, + rope_scaling=rope_scaling, + position_embedding_type="rope", + architectures=["DeciLMForCausalLM"], + model_type="nemotron-nas", + eos_token_id=eos_token_id, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + head_dim=nemo_config.kv_channels, + # Restore auto_map from preserved metadata (needed for trust_remote_code loading) + auto_map=source_hf_config_metadata.get( + "auto_map", + { + "AutoConfig": "configuration_decilm.DeciLMConfig", + "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM", + }, + ), + # Restore dtype field from preserved metadata + dtype=source_hf_config_metadata.get("dtype", "bfloat16"), + ) + + return hf_config + + +def _config_to_dict(config) -> Dict[str, Any]: + """Convert a config object to a dictionary. + + Args: + config: Either an object with attributes or already a dictionary + + Returns: + Dictionary representation of the config + """ + if isinstance(config, dict): + return config + return vars(config) + + +def _build_puzzletron_mappings_and_transforms( + source_config: PuzzletronHeterogeneousTransformerConfig, +) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str], Dict[str, str], List[Dict[str, Any]]]: + """Build mappings and transform specifications for Puzzletron heterogeneous models. + + Args: + source_config: The Puzzletron heterogeneous transformer configuration + + Returns: + Tuple containing: + - attn_mapping: Attention layer mappings + - ffn_mapping: FFN layer mappings + - mamba_mapping: Mamba layer mappings + - moe_mapping: MoE layer mappings + - transform_specs: List of transform specifications with source_key, target_key, transform_function + """ + attn_mapping = {} + ffn_mapping = {} + mamba_mapping = {} + moe_mapping = {} + transform_specs = [] + + # Determine config type and extract block configs + is_hf_config = hasattr(source_config, "block_configs") + is_nemo_config = ( + hasattr(source_config, "heterogeneous_layers_config_encoded_json") + and source_config.heterogeneous_layers_config_encoded_json + ) + assert not (is_hf_config and is_nemo_config), "Cannot have both HF and NeMo config" + + if is_hf_config: + # HF config case (importer) + block_configs = source_config.block_configs + elif is_nemo_config: + # NeMo config case (exporter) - parse JSON + try: + heterogeneous_config = json.loads( + source_config.heterogeneous_layers_config_encoded_json + ) + block_configs = heterogeneous_config.get("block_configs", []) + except (json.JSONDecodeError, KeyError): + block_configs = [] + else: + block_configs = [] + + # Check if we found any block configs + if not block_configs: + raise ValueError( + "No block configs found in source configuration. " + "Expected either 'block_configs' attribute (HF config) or " + "'heterogeneous_layers_config_encoded_json' attribute (NeMo config) with valid block configs." + ) + + # TODO it is better (more stable) to use target.config.block_configs + for idx, block_config in enumerate(block_configs): + # Convert block config to dictionary + block_dict = _config_to_dict(block_config) + + # Extract attention and FFN configs (handle both HF "ffn" and NeMo "mlp" keys) + attn = block_dict.get("attention") + ffn = block_dict.get("ffn") or block_dict.get("mlp") + + # Convert sub-configs to dictionaries + attn_dict = _config_to_dict(attn) if attn else {} + ffn_dict = _config_to_dict(ffn) if ffn else {} + + # Process attention config + # Handle both HF (mamba) and NeMo (is_mamba) keys + is_mamba = attn_dict.get("mamba") or attn_dict.get("is_mamba") + + if not attn or attn_dict.get("no_op"): + value = None + elif attn_dict.get("replace_with_linear"): + value = f"decoder.layers.{idx}.self_attention.layer_norm_weight" + elif is_mamba is not None: + value = f"decoder.layers.{idx}.self_attention.in_proj.layer_norm_weight" + for mamba_key in [ + "dt_bias", + "A_log", + "D", + "in_proj.weight", + "conv1d.weight", + "conv1d.bias", + "norm.weight", + "out_proj.weight", + ]: + mamba_mapping[f"model.layers.{idx}.self_attn.mamba_mixer.{mamba_key}"] = ( + f"decoder.layers.{idx}.self_attention.{mamba_key}" + ) + else: + value = f"decoder.layers.{idx}.self_attention.linear_qkv.layer_norm_weight" + # Store transform spec for QKV merging + transform_specs.append( + { + "source_key": ( + f"model.layers.{idx}.self_attn.q_proj.weight", + f"model.layers.{idx}.self_attn.k_proj.weight", + f"model.layers.{idx}.self_attn.v_proj.weight", + ), + "target_key": f"decoder.layers.{idx}.self_attention.linear_qkv.weight", + "transform_function": "merge_qkv_for_puzzletron", + "kwargs": {"idx": idx}, + } + ) + + if value is not None: + attn_mapping[f"model.layers.{idx}.input_layernorm.weight"] = value + + # Process FFN config + # Handle both HF (moe, moe.shared_expert_intermediate_dim) and NeMo (num_moe_experts, moe_shared_expert_intermediate_size) keys + moe_config = ffn_dict.get("moe") or ffn_dict.get("num_moe_experts") + shared_expert_size = None + if moe_config: + # Convert moe_config to dict if it's an object (HF case) + moe_dict = ( + _config_to_dict(moe_config) if not isinstance(moe_config, (int, type(None))) else {} + ) + shared_expert_size = moe_dict.get("shared_expert_intermediate_dim") or ffn_dict.get( + "moe_shared_expert_intermediate_size" + ) + + if not ffn or ffn_dict.get("no_op"): + value = None + elif ffn_dict.get("replace_with_linear"): + value = f"decoder.layers.{idx}.mlp.layer_norm_weight" + elif moe_config is not None: + value = f"decoder.layers.{idx}.pre_mlp_layernorm.weight" + moe_mapping[f"model.layers.{idx}.mlp.router.weight"] = ( + f"decoder.layers.{idx}.mlp.router.weight" + ) + # Store transform spec for MoE expert FC1 merging + transform_specs.append( + { + "source_key": ( + f"model.layers.{idx}.mlp.experts.*.gate_proj.weight", + f"model.layers.{idx}.mlp.experts.*.up_proj.weight", + ), + "target_key": f"decoder.layers.{idx}.mlp.experts.local_experts.*.linear_fc1.weight", + "transform_function": "merge_fc1_for_moe", + "kwargs": {}, + } + ) + moe_mapping[f"model.layers.{idx}.mlp.experts.*.down_proj.weight"] = ( + f"decoder.layers.{idx}.mlp.experts.local_experts.*.linear_fc2.weight" + ) + # Check for shared expert + if shared_expert_size not in [None, 0]: + # Store transform spec for MoE shared expert FC1 merging + transform_specs.append( + { + "source_key": ( + f"model.layers.{idx}.mlp.shared_expert.gate_proj.weight", + f"model.layers.{idx}.mlp.shared_expert.up_proj.weight", + ), + "target_key": f"decoder.layers.{idx}.mlp.shared_experts.linear_fc1.weight", + "transform_function": "merge_fc1_for_moe", + "kwargs": {}, + } + ) + moe_mapping[f"model.layers.{idx}.mlp.shared_expert.down_proj.weight"] = ( + f"decoder.layers.{idx}.mlp.shared_experts.linear_fc2.weight" + ) + else: + value = f"decoder.layers.{idx}.mlp.linear_fc1.layer_norm_weight" + + if value is not None: + ffn_mapping[f"model.layers.{idx}.post_attention_layernorm.weight"] = value + + return attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs + + +def merge_qkv_for_puzzletron( + ctx: io.state.TransformCTX, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + idx: Optional[int] = None, +): + """ + Merge q, k, v to interleave-concatenated qkv. + - Modified version of nemo.lightning.io.state.TransformFns.merge_qkv for Puzzletron + - idx can be provided to fetch megatron_config for a specific layer + - heads_per_group is derived from the shape of q and k, instead of calculating (head_num // num_query_groups) from config values + - num_query_groups is not fetched from a global config value, but calculated from head_num and heads_per_group + + Example: import HF {q|k|v}_proj to layer linear_qkv + """ + if idx is not None: + megatron_config = ctx.target.decoder.layers[idx].config + else: + megatron_config = ctx.target.config + head_num = megatron_config.num_attention_heads + heads_per_group = ( + q.shape[0] // k.shape[0] + ) # NOTE: This is important to support heterogeneous attention + num_query_groups = head_num // heads_per_group + hidden_size = megatron_config.hidden_size + head_size = megatron_config.kv_channels + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +def split_qkv_for_puzzletron( + ctx: io.state.TransformCTX, qkv: torch.Tensor, idx: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split interleave-concatenated qkv to separate q, k, v. + - Inverse operation of merge_qkv_for_puzzletron for Puzzletron + - idx can be provided to fetch megatron_config for a specific layer + - heads_per_group is derived from the shape of qkv, instead of calculating from config values + - num_query_groups is not fetched from a global config value, but calculated from head_num and heads_per_group + + Example: export NeMo layer linear_qkv to HF {q|k|v}_proj + """ + if idx is not None: + megatron_config = ctx.source.decoder.layers[idx].config + else: + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + head_size = megatron_config.kv_channels + # hidden_size = megatron_config.hidden_size + + # Calculate qkv_total_dim from the actual qkv tensor shape + # qkv shape is [head_size * (head_num + 2 * num_query_groups), hidden_size] + qkv_total_dim = qkv.shape[0] // head_size + num_query_groups = (qkv_total_dim - head_num) // 2 + heads_per_group = head_num // num_query_groups + + # Reshape qkv to 3D: [qkv_total_dim, head_size, hidden_size] + qkv = qkv.reshape([qkv_total_dim, head_size, -1]) + + # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size + # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r + actual_hidden_size = qkv.size(-1) + + # Create slice indices for q, k, v + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = qkv[q_slice].reshape(-1, actual_hidden_size).cpu() + k_proj = qkv[k_slice].reshape(-1, actual_hidden_size).cpu() + v_proj = qkv[v_slice].reshape(-1, actual_hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +def dtype_from_dict(config_dict): + """ + Extracts torch dtype from a HF config. + Handles both 'torch_dtype' (old format) and 'dtype' (new format). + """ + # Try torch_dtype first (old format), then dtype (new format) + if "torch_dtype" in config_dict: + torch_dtype = config_dict["torch_dtype"] + elif "dtype" in config_dict: + torch_dtype = config_dict["dtype"] + else: + raise ValueError("Expected config dict to have attr `torch_dtype` or `dtype`") + + if isinstance(torch_dtype, torch.dtype): + return torch_dtype + elif isinstance(torch_dtype, str): + return dtype_from_str(torch_dtype) + else: + raise ValueError(f"dtype is not of type str/torch.dtype, got {type(torch_dtype)}") + + +def copy_puzzletron_python_files_to_decilm_checkpoint(output_path: str) -> None: + """Copy custom Python files from puzzle_tools package to output directory. + + Puzzletron models require custom Python files (configuration_decilm.py, + modeling_decilm.py, etc.) to be present in the checkpoint directory for + loading with transformers.AutoModel. + + This function copies all Python files from puzzle_tools/deci_lm_hf_code/ + to ensure the exported checkpoint is fully functional. + + Args: + output_path: Directory where HF model is being saved + """ + import logging + import shutil + from pathlib import Path + + # Get the puzzle_tools/deci_lm_hf_code directory + # Navigate from this file: export/MCore/llama_nemotron_utils.py -> v1/puzzle_tools/deci_lm_hf_code/ + package_dir = Path(__file__).parent.parent.parent / "puzzle_tools" / "deci_lm_hf_code" + + if not package_dir.exists(): + logging.warning( + f"Custom files directory not found: {package_dir}. " + f"Exported checkpoint may not be loadable without these files." + ) + return + + # Copy all Python files from the package + output_dir = Path(output_path) + copied_files = [] + for src_file in package_dir.glob("*.py"): + dest_file = output_dir / src_file.name + shutil.copy2(src_file, dest_file) + copied_files.append(src_file.name) + + logging.info(f"Copied {len(copied_files)} custom Python files to {output_path}") + logging.debug(f"Custom files copied: {', '.join(sorted(copied_files)[:5])}...") # Show first 5 + + +def embed_chat_template_in_tokenizer_config(nemo_checkpoint_path: str, output_path: str) -> None: + """Embed chat_template from .jinja file into tokenizer_config.json. + + NeMo's HF → NeMo import extracts chat_template to a separate .jinja file + but doesn't preserve it in tokenizer_config.json. This causes accuracy drops + in evaluation. This function restores it by: + 1. Reading chat_template.jinja from the NeMo checkpoint + 2. Embedding it into the exported tokenizer_config.json + + Args: + nemo_checkpoint_path: Path to the NeMo checkpoint (.nemo file/directory) + output_path: Directory where HF model is being saved + """ + import logging + from pathlib import Path + + # Path to NeMo checkpoint tokenizer files + nemo_checkpoint = Path(nemo_checkpoint_path) + nemo_chat_template_jinja = ( + nemo_checkpoint / "context" / "nemo_tokenizer" / "chat_template.jinja" + ) + + # Path to exported tokenizer config + output_dir = Path(output_path) + output_tokenizer_config = output_dir / "tokenizer_config.json" + + # Check if both files exist + if not nemo_chat_template_jinja.exists(): + logging.debug( + f"No chat_template.jinja found in NeMo checkpoint at {nemo_chat_template_jinja}" + ) + return + + if not output_tokenizer_config.exists(): + logging.warning(f"tokenizer_config.json not found at {output_tokenizer_config}") + return + + # Read chat_template from .jinja file + chat_template_content = nemo_chat_template_jinja.read_text() + + # Load tokenizer_config.json + with open(output_tokenizer_config, "r") as f: + tokenizer_config = json.load(f) + + # Check if chat_template is already embedded (shouldn't be, but be safe) + if "chat_template" in tokenizer_config: + logging.debug("chat_template already embedded in tokenizer_config.json, skipping") + return + + # Embed the chat_template + tokenizer_config["chat_template"] = chat_template_content + + # Save updated tokenizer_config.json + with open(output_tokenizer_config, "w") as f: + json.dump(tokenizer_config, f, indent=2, ensure_ascii=False) + + logging.info(f"✓ Embedded chat_template from NeMo checkpoint into tokenizer_config.json") + logging.debug(f" Template length: {len(chat_template_content)} characters") diff --git a/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py b/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py new file mode 100644 index 0000000000..11a8798ba6 --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict + +import torch +from megatron.core.transformer.spec_utils import ModuleSpec +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import ( + AutoTokenizer as NemoAutoTokenizer, +) +from nemo.collections.llm.gpt.model.base import GPTModel +from nemo.collections.llm.gpt.model.llama_nemotron import HFLlamaNemotronImporter +from nemo.lightning import io, teardown +from nemo.lightning.io.state import TransformFns +from nemo.utils.import_utils import safe_import +from transformers import AutoModelForCausalLM, AutoTokenizer + +from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( + PuzzletronAttentionConfig, + PuzzletronHeterogeneousTransformerConfig, + PuzzletronMLPConfig, + get_gpt_heterogeneous_layer_spec_puzzletron, +) + + +def convert_attention_config_from_cfg_object(attention_config, num_attention_heads, head_dim): + for unsupported_key in [ + "llama4", + "num_sink_tokens", + "sparsify", + "unshifted_sink", + "use_prefill_window_in_sink_attention", + ]: + if hasattr(attention_config, unsupported_key) and getattr( + attention_config, unsupported_key + ) not in [ + None, + False, + ]: + # + # if attention_config.get(unsupported_key, None) not in [None, False]: + raise NotImplementedError(f"{unsupported_key} is not supported") + window_size = attention_config.window_size if hasattr(attention_config, "window_size") else None + if window_size is not None: + window_size = (window_size, 0) + is_mamba = attention_config.mamba if hasattr(attention_config, "mamba") else False + n_heads_in_group = ( + attention_config.n_heads_in_group if hasattr(attention_config, "n_heads_in_group") else 1 + ) + if n_heads_in_group is None: + n_heads_in_group = 1 + return asdict( + PuzzletronAttentionConfig( + no_op=attention_config.no_op if hasattr(attention_config, "no_op") else False, + replace_with_linear=( + attention_config.replace_with_linear + if hasattr(attention_config, "replace_with_linear") + else False + ), + num_attention_heads=num_attention_heads, + num_query_groups=num_attention_heads // n_heads_in_group, + kv_channels=head_dim, + window_size=window_size, + multi_latent_attention=False, + is_mamba=is_mamba, + mamba_state_dim=( + attention_config.mamba.state_dim + if is_mamba and hasattr(attention_config.mamba, "state_dim") + else 128 + ), + mamba_head_dim=( + attention_config.mamba.head_dim + if is_mamba and hasattr(attention_config.mamba, "head_dim") + else 64 + ), + mamba_num_groups=( + attention_config.mamba.num_groups + if is_mamba and hasattr(attention_config.mamba, "num_groups") + else 8 + ), + mamba_num_heads=( + attention_config.mamba.num_heads + if is_mamba and hasattr(attention_config.mamba, "num_heads") + else None + ), + ) + ) + + +def convert_mlp_config_from_cfg_object(mlp_config, parallel_blocks, default_hidden_act): + if parallel_blocks is not None: + raise NotImplementedError("parallel_blocks is not supported") + if not hasattr(mlp_config, "gated") or mlp_config.gated is False: + raise NotImplementedError("non-gated MLP is not supported") + if not hasattr(mlp_config, "hidden_act") or mlp_config.hidden_act not in [default_hidden_act]: + raise NotImplementedError(f"all mlps must have the same activation ({default_hidden_act})") + if hasattr(mlp_config, "sparsify") and mlp_config.sparsify is not None: + raise NotImplementedError("sparsify is not supported") + is_moe = hasattr(mlp_config, "moe") and mlp_config.moe is not None + return asdict( + PuzzletronMLPConfig( + no_op=mlp_config.no_op if hasattr(mlp_config, "no_op") else False, + replace_with_linear=mlp_config.replace_with_linear + if hasattr(mlp_config, "replace_with_linear") + else False, + ffn_hidden_size=mlp_config.intermediate_size + if hasattr(mlp_config, "intermediate_size") + else None, + num_moe_experts=( + mlp_config.moe.num_local_experts + if is_moe and hasattr(mlp_config.moe, "num_local_experts") + else None + ), + moe_shared_expert_intermediate_size=( + mlp_config.moe.shared_expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "shared_expert_intermediate_dim") + else None + ), + moe_ffn_hidden_size=( + mlp_config.moe.expert_intermediate_dim + if is_moe and hasattr(mlp_config.moe, "expert_intermediate_dim") + else None + ), + moe_router_topk=( + mlp_config.moe.num_experts_per_tok + if is_moe and hasattr(mlp_config.moe, "num_experts_per_tok") + else 2 + ), + ) + ) diff --git a/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py b/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py new file mode 100644 index 0000000000..ec011ff288 --- /dev/null +++ b/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py @@ -0,0 +1,928 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import asdict, dataclass, field, fields +from pathlib import Path +from typing import Any, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.gpt.gpt_layer_specs import ( + LayerType, + LNImpl, + TransformerBlockSubmodules, + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, + get_num_layers_to_build, + get_transformer_layer_offset, +) +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.post_training.modelopt.layers import Linear +from megatron.core.process_groups_config import ModelCommProcessGroups +from megatron.core.quantization.utils import ( + kitchen_quantization_recipe_config, + load_quantization_recipe, +) +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.tensor_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + _initialize_affine_weight_cpu, +) +from megatron.core.tensor_parallel.random import get_cuda_rng_tracker +from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version + +# from megatron.core.activations import squared_relu #for megatron 0.14 version in future NeMo containers +from megatron.training.activations import squared_relu +from nemo.collections.llm.gpt.model.llama import Llama31Config70B +from packaging.version import Version as PkgVersion +from torch import Tensor +from torch.nn.parameter import Parameter + +try: + import transformer_engine as te # pylint: disable=unused-import + from megatron.core.extensions.transformer_engine import ( + TELayerNormColumnParallelLinear, + TELinear, + TENorm, + TERowParallelLinear, + _get_extra_te_kwargs, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +# TODO: check sharded_state_dict_keys_map => only if TE is disabled +# TODO: parallel Blocks +# TODO: multimodal +# https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/vlm/neva/model/base.py +# https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/vlm/qwen2vl/model/base.py + + +# NOTE based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py#L144 +# TODO: what is the difference between this and the referenced one? +def _get_sharded_state_dict_keys_map( + block_config: "PuzzletronTransformerBlockConfig", use_transformer_engine: bool +): + """Generate a mapping of sharded state dictionary keys for Puzzletron transformer blocks. + + This function is a specialized version of the original Megatron-LM + `_get_sharded_state_dict_keys_map` function, adapted for Puzzletron's + heterogeneous transformer architecture with Mamba support. + + Key differences from the original: + - **Mamba Support**: Adds mapping for Mamba layers (`mixer.norm_`) + - **Enhanced Logic**: Uses if-elif-else structure instead of multiple if statements + - **No-op Handling**: Explicit handling of no-op attention and MLP cases + - **Simplified**: Removes `num_query_groups` check (handled in main logic) + - **Config Type**: Uses `PuzzletronTransformerBlockConfig` instead of `TransformerBlockConfig` + + Args: + block_config: Puzzletron transformer block configuration + use_transformer_engine: Whether to use Transformer Engine optimizations + + Returns: + dict: A dictionary mapping sharded state dictionary keys + """ + mapping = {} + if not use_transformer_engine: + if block_config.attention.replace_with_linear: + mapping.update({"input_layernorm.": "self_attention.layer_norm_"}) + elif block_config.attention.is_mamba: # Mamba, not sure about this + mapping.update({"input_layernorm.": "mixer.norm_"}) + elif not block_config.attention.no_op: # MHA and MLA + mapping.update({"input_layernorm.": "self_attention.linear_qkv.layer_norm_"}) + else: # No-op + pass + + if block_config.mlp.ffn_hidden_size is not None: # FFN + mapping.update({"pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_"}) + elif block_config.mlp.replace_with_linear: # Linear + mapping.update({"pre_mlp_layernorm.": "mlp.layer_norm_"}) + else: # No-op, MoE + pass + return mapping + + +# NOTE: new class +@dataclass +class PuzzletronSubblockConfig: + """Base configuration class for Puzzletron transformer subblocks. + + This is the base class for attention and MLP configurations in Puzzletron's + heterogeneous transformer architecture. It provides common functionality + for subblock configurations including no-op and linear replacement options. + + Key differences from the original Megatron-LM subblock configs: + - **Enhanced Building**: Uses `build_config_from_dict()` with main config fallback + - **Validation**: Includes `__post_init__()` validation for mutual exclusivity + - **Flexibility**: Supports both no-op and linear replacement modes + + Attributes: + no_op: Whether this subblock should be a no-op operation + replace_with_linear: Whether to replace the subblock with a single linear layer + """ + + no_op: bool = False + replace_with_linear: bool = False + + @classmethod + def build_config_from_dict( + cls, + subblock_config_dict: dict[str, Any], + main_config: "PuzzletronHeterogeneousTransformerConfig", + ): + field_names = {f.name for f in fields(cls)} + subblock_config_dict = {k: v for k, v in subblock_config_dict.items() if k in field_names} + # getting default values from the main config (if not overridden in the subblock config) + for field_name in field_names: + # note that MLA fields are also not in the main_config + if field_name not in subblock_config_dict and hasattr(main_config, field_name): + subblock_config_dict[field_name] = getattr(main_config, field_name) + return cls(**subblock_config_dict) + + def __post_init__(self) -> None: + assert not (self.no_op and self.replace_with_linear), ( + "at most one of no_op, replace_with_linear can be True" + ) + + +@dataclass +class PuzzletronAttentionConfig(PuzzletronSubblockConfig): + """Configuration parameters for the self-attention part of a Puzzletron transformer block. + + This class extends the original Megatron-LM AttentionConfig with support for + Mamba layers and enhanced Multi-Latent Attention (MLA) configurations. + + Key differences from the original AttentionConfig: + - **Mamba Support**: Adds `is_mamba` flag and Mamba-specific parameters + - **Enhanced MLA**: Extended MLA parameters with LoRA ranks and head dimensions + - **Context Parallelism**: Adds `cp_comm_type` for attention context parallelism + - **Validation**: Enhanced `__post_init__()` with Mamba-MLA mutual exclusivity check + - **Flexibility**: Supports MHA, MLA, and Mamba attention types in one config + + Attributes: + # MHA (Multi-Head Attention) parameters + num_attention_heads: Number of attention heads + num_query_groups: Number of query groups for grouped query attention + kv_channels: Key-value projection dimension + window_size: Sliding window size for local attention + + # MLA (Multi-Latent Attention) parameters + multi_latent_attention: Whether to use MLA instead of MHA + q_lora_rank: LoRA rank for query projections + kv_lora_rank: LoRA rank for key-value projections + qk_head_dim: Query-key head dimension + qk_pos_emb_head_dim: Query-key positional embedding head dimension + v_head_dim: Value head dimension + + # Context parallelism + cp_comm_type: Communication type for context parallelism + + # Mamba parameters + is_mamba: Whether to use Mamba instead of attention + mamba_state_dim: Mamba state dimension + mamba_head_dim: Mamba head dimension + mamba_num_groups: Number of groups in Mamba + mamba_num_heads: Number of heads in Mamba (auto-calculated if None) + """ + + # all attributes, except for is_mamba are part of TransformerConfig/MLATransformerConfig + # MHA + num_attention_heads: Optional[int] = None + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + window_size: Optional[Tuple[int, int]] = None + # MLA (Note that for MLA we have to instantiate a MLATransformerConfig, since there is a isinstance in attention.py) + multi_latent_attention: bool = False + q_lora_rank: int = 512 + kv_lora_rank: int = 512 + qk_head_dim: int = 128 + qk_pos_emb_head_dim: int = 64 + v_head_dim: int = 128 + # for attention context parallelism (ignored for mamba) + cp_comm_type: str = "p2p" + # Mamba + is_mamba: bool = False # new + mamba_state_dim: int = 128 + mamba_head_dim: int = 64 + mamba_num_groups: int = 8 + mamba_num_heads: Optional[int] = None + + def __post_init__(self) -> None: + super().__post_init__() + if self.no_op or self.replace_with_linear: + self.is_mamba = False + self.num_attention_heads = 8 + self.multi_latent_attention = False + if self.is_mamba: + if self.num_attention_heads is None or self.num_attention_heads == 0: + self.num_attention_heads = 8 # to avoid division by zero + assert not (self.is_mamba and self.multi_latent_attention), ( + "Mamba and MLA cannot be used together" + ) + + +@dataclass +class PuzzletronMLPConfig(PuzzletronSubblockConfig): + """Configuration parameters for the MLP part of a Puzzletron transformer block. + + This class extends the original Megatron-LM MLPConfig with enhanced + Mixture of Experts (MoE) support and improved configuration building. + + Key differences from the original MLPConfig: + - **Enhanced MoE**: Extended MoE parameters with shared expert support + - **Validation**: Includes `__post_init__()` validation for no-op/linear modes + - **Building**: Uses `build_config_from_dict()` with main config fallback + - **Flexibility**: Supports standard MLP, MoE, no-op, and linear replacement modes + + Attributes: + # Standard MLP parameters + ffn_hidden_size: MLP intermediate size (hidden dimension) + + # MoE (Mixture of Experts) parameters + num_moe_experts: Number of expert networks in MoE + moe_shared_expert_intermediate_size: Size of shared expert intermediate layer + moe_ffn_hidden_size: Hidden size for MoE expert networks + moe_router_topk: Number of top-k experts to route tokens to + """ + + # all attributes are part of TransformerConfig + ffn_hidden_size: Optional[int] = None + # MoE + num_moe_experts: Optional[int] = None + moe_shared_expert_intermediate_size: Optional[int] = None + moe_ffn_hidden_size: Optional[int] = None + moe_router_topk: int = 2 + + def __post_init__(self) -> None: + super().__post_init__() + if self.no_op or self.replace_with_linear: + self.ffn_hidden_size = None + self.num_moe_experts = None + self.moe_ffn_hidden_size = None + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/transformer/heterogeneous/heterogeneous_config.py#L134 +@dataclass +class PuzzletronTransformerBlockConfig: + """Configuration for a single Puzzletron transformer block in a heterogeneous model. + + This class represents the configuration for one transformer block, containing + both attention and MLP subblock configurations. It's based on the original + Megatron-LM TransformerBlockConfig but uses Puzzletron-specific subblock configs. + + Key differences from the original TransformerBlockConfig: + - **Puzzletron Subblocks**: Uses `PuzzletronAttentionConfig` and `PuzzletronMLPConfig` + - **Enhanced Building**: Uses `build_from_dict()` with main config fallback + - **Mamba Support**: Supports Mamba layers through attention config + - **MoE Support**: Enhanced MoE support through MLP config + - **Flexibility**: Supports all Puzzletron attention and MLP variants + + Attributes: + attention: Configuration for the attention subblock (MHA, MLA, or Mamba) + mlp: Configuration for the MLP subblock (standard MLP or MoE) + """ + + attention: PuzzletronAttentionConfig + mlp: PuzzletronMLPConfig + + @classmethod + def build_from_dict( + cls, block: dict[str, Any], main_config: "PuzzletronHeterogeneousTransformerConfig" + ): + if "mlp" in block: + mlp = block["mlp"] + elif "ffn" in block: + mlp = block["ffn"] + else: + raise ValueError(f"mlp/ffn not found in block: {block}") + + return cls( + attention=PuzzletronAttentionConfig.build_config_from_dict( + subblock_config_dict=block["attention"], main_config=main_config + ), + mlp=PuzzletronMLPConfig.build_config_from_dict( + subblock_config_dict=mlp, main_config=main_config + ), + ) + + +@dataclass +class PuzzletronMambaTransformerConfig(TransformerConfig): + """Configuration for Puzzletron Mamba-only transformer models. + + This class extends the base TransformerConfig for models that use + Mamba layers exclusively instead of attention mechanisms. It inherits + all standard transformer configuration parameters from TransformerConfig. + + Key differences from standard TransformerConfig: + - **Mamba Focus**: Designed specifically for Mamba-based architectures + - **Inheritance**: Inherits all standard transformer parameters + - **Simplicity**: Currently a pass-through class for future Mamba-specific extensions + + Note: This class is currently minimal and inherits all functionality + from the base TransformerConfig. Future versions may add Mamba-specific + configuration parameters as needed. + """ + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/transformer/heterogeneous/heterogeneous_config.py#L147 +@dataclass +class PuzzletronHeterogeneousTransformerConfig(TransformerConfig): + """Configuration object for Puzzletron heterogeneous transformers. + + This class extends the original Megatron-LM HeterogeneousTransformerConfig with + enhanced support for Mamba layers and improved configuration management. + + Key differences from the original HeterogeneousTransformerConfig: + - **Mamba Support**: Adds Mamba-specific parameters for state-space models + - **Enhanced Block Configs**: Uses `PuzzletronTransformerBlockConfig` with Mamba support + - **Improved Building**: Enhanced `__post_init__()` with better config validation + - **Flexibility**: Supports all Puzzletron attention and MLP variants + + Heterogeneous models refer to transformer architectures where individual layers can differ + in configuration. Specifically: + - Attention layers can be MHA, MLA, Mamba, Linear, or No-op (all with their own config) + - MLP layers can be MoE, MLP, Linear, or No-op (all with their own config) + - Layers can have parallel blocks that run simultaneously and sum their outputs + + Mamba Parameters (shared across all Mamba layers): + d_conv: Convolution dimension for Mamba + expand: Expansion factor for Mamba hidden dimension + D_has_hdim: Whether D matrix has hidden dimension + rmsnorm: Whether to use RMS normalization + norm_before_gate: Whether to normalize before gating + dt_min/max/scale: Delta time parameters for Mamba + bias/conv_bias: Bias settings for Mamba layers + chunk_size: Chunk size for Mamba processing + """ + + heterogeneous_layers_config_path: str = "" + """Path to the json file containing the heterogeneous block specs.""" + + heterogeneous_layers_config_encoded_json: str = "" + """The contents of the json file containing the heterogeneous block specs. It will be read from + heterogeneous_layers_config_path at first, then saved forever inside the model checkpoint.""" + + per_block_parameters: list[PuzzletronTransformerBlockConfig] = field(init=False) + """Configuration parameters for each of the transformer blocks in a + heterogeneous transformer.""" + + # all of these can be used to instantiate a MambaMixer, they are shared for all Mamba layers + d_conv: int = 4 + expand: int = 2 + D_has_hdim: bool = False + rmsnorm: bool = True + norm_before_gate: bool = False + dt_min: float = 0.001 + dt_max: float = 0.1 + dt_scale: float = 1.0 + bias: bool = False + conv_bias: bool = True + chunk_size: int = 128 + + def __post_init__(self) -> None: + if self.kv_channels is None and self.num_attention_heads == 0: + self.num_attention_heads = 8 # to avoid division by zero + # Type assertion to help mypy understand the type after the check + assert isinstance(self.num_attention_heads, int), "num_attention_heads must be an integer" + if self.heterogeneous_layers_config_encoded_json in ("", None): + assert self.heterogeneous_layers_config_path not in ( + None, + "", + ), ( + "heterogeneous_layers_config_path is required, if heterogeneous_layers_config_encoded_json is not provided" + ) + self.heterogeneous_layers_config_encoded_json = Path( + self.heterogeneous_layers_config_path + ).read_text() + hf_config_dict: dict[str, Any] = json.loads(self.heterogeneous_layers_config_encoded_json) + block_list = hf_config_dict["block_configs"] + # TODO: should we change the definition of num_layers? it can be sum(mlp/attention) rather than uneven blocks + if self.num_layers is None or self.num_layers == 0: + self.num_layers = len(block_list) + # Type assertion to help mypy understand the type after the check + assert isinstance(self.num_layers, int), "num_layers must be an integer" + assert self.num_layers == len(block_list), ( + "num_layers must match the number of blocks in the json file" + ) + super().__post_init__() + self.heterogeneous_block_specs = True + self.heterogeneous_dist_checkpoint = True # TODO: check if this is correct/needed + self.per_block_parameters = [ + PuzzletronTransformerBlockConfig.build_from_dict(block=block, main_config=self) + for block in block_list + ] + + # TODO add parallel blocks support + def get_config_for_layer( + self, layer_number: int + ) -> TransformerConfig | MLATransformerConfig | PuzzletronMambaTransformerConfig: + """ + Get the config for the given layer number. + Based on the layer number, the corresponding block config is returned, + overriding the main config's value. + + Returns: + TransformerConfig: For standard transformer layers + MLATransformerConfig: For MLA layers + PuzzletronMambaTransformerConfig: For Mamba layers + """ + layer_idx = layer_number - 1 # layer number starts from 1 + if layer_idx < 0 or layer_idx >= len(self.per_block_parameters): + raise ValueError( + f"Invalid layer number: {layer_number}. Should be in " + f"range [1, {len(self.per_block_parameters)}]." + ) + block_config = self.per_block_parameters[layer_idx] + + # Determine which config class to use based on the block configuration + if block_config.attention.is_mamba: + config_class = PuzzletronMambaTransformerConfig + elif block_config.attention.multi_latent_attention: + config_class = MLATransformerConfig + else: + config_class = TransformerConfig + + # Get all available fields from the attention and MLP configs + attention_fields = {f.name for f in fields(block_config.attention)} + mlp_fields = {f.name for f in fields(block_config.mlp)} + + # Get all available fields from the target config class + target_config_fields = {f.name for f in fields(config_class)} + + # Start with the base config + transformer_config_dict = asdict(self) + + # Remove keys that are not in the target config class + transformer_config_dict = { + k: v for k, v in transformer_config_dict.items() if k in target_config_fields + } + + # Update with all available attention config values (if they exist in target config) + for field_name in attention_fields: + if field_name in target_config_fields: + transformer_config_dict[field_name] = getattr(block_config.attention, field_name) + + # Update with all available MLP config values (if they exist in target config) + for field_name in mlp_fields: + if field_name in target_config_fields: + transformer_config_dict[field_name] = getattr(block_config.mlp, field_name) + + if transformer_config_dict["num_moe_experts"] is None: + # to pass __post_init__ of config_class + transformer_config_dict["expert_model_parallel_size"] = 1 + config = config_class(**transformer_config_dict) + + return config + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/ba97a7e282a8478a02d012bc9b9e45f3a6be216e/megatron/core/extensions/transformer_engine.py#L449 +class WrappedTENormLinear(TELayerNormColumnParallelLinear): + """A wrapper around TELayerNormColumnParallelLinear with simplified interface and forced configurations. + + This wrapper simplifies the interface of TELayerNormColumnParallelLinear by: + 1. Taking only a config object instead of individual parameters + 2. Forcing specific configurations (tp_group=None, tp_size=1, etc.) for compatibility + 3. Adding version checks for Transformer Engine features + 4. Providing a cleaner interface for heterogeneous transformer models + + Key differences from TELayerNormColumnParallelLinear: + - Simplified constructor: only requires config and optional unused parameters + - Forces tensor parallel settings: tp_group=None, tp_size=1, tp_rank=0 + - Automatically sets input_size=output_size=config.hidden_size + - Adds version checks for TE features (delay_wgrad_compute, normalization, symmetric_ar_type) + - Forces bias=False, skip_bias_add=False for consistency + - Disables gather_output (raises error if True) + - Uses simplified init_method=lambda w: None + + This wrapper is designed for use in heterogeneous transformer architectures where + individual layers may have different configurations but need a consistent interface. + """ + + def __init__( + self, + config, + layer_number=None, # unused + model_comm_pgs=None, # unused + cp_comm_type=None, # unused + tp_group=None, # unused + tp_comm_buffer_name=None, + gather_output=False, # unused + ): + # unfortunately, TELayerNormColumnParallelLinear sets tp_group and forcing it to be None requires to copy/paste __init__ + if not HAVE_TE: + raise ImportError( + "Transformer Engine is not installed. " + "Please install it with `pip install transformer-engine`." + ) + + self.config = config + + if gather_output: + raise ValueError("Transformer Engine linear layers do not support gather_output = True") + + skip_bias_add = False + bias = False + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + extra_kwargs = _get_extra_te_kwargs(config) + self.tp_size = 1 + self.tp_rank = 0 + + if self.config.delay_wgrad_compute: + if is_te_min_version("2.3.0"): + extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute + else: + raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") + + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` + if is_te_min_version("0.11.0"): + extra_kwargs["normalization"] = self.config.normalization + elif self.config.normalization != "LayerNorm": + te_version = get_te_version() + raise ValueError( + f"Transformer Engine v{te_version} does not support {self.config.normalization}." + ) + + if self.config.symmetric_ar_type is not None: + assert is_torch_min_version("2.7.0a0"), "Must have at least torch version 2.7 or higher" + assert is_te_min_version("2.3.0") or get_te_version() == PkgVersion( + "2.3.0.dev0+39c0e70" + ), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce" + extra_kwargs["symmetric_ar_type"] = self.config.symmetric_ar_type + + output_size = config.hidden_size + input_size = config.hidden_size + # calling te.pytorch.LayerNormLinear's __init__ + super(TELayerNormColumnParallelLinear, self).__init__( + in_features=input_size, + out_features=output_size, + eps=self.config.layernorm_epsilon, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=None, + tp_size=1, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=lambda w: None, + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=None, + return_layernorm_output=False, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + **extra_kwargs, + ) + + if config.use_cpu_initialization: + output_size_per_partition = output_size + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=lambda w: None, + stride=1, + return_master_weight=False, + rank=self.tp_rank, + world_size=self.tp_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + with torch.no_grad(): + self.bias.zero_() + + def forward(self, x, *args, **kwargs): + return super().forward(x) + + +class WrappedLinear(Linear): + def __init__( + self, + config, + layer_number=None, + model_comm_pgs=None, + cp_comm_type=None, + tp_group=None, + tp_comm_buffer_name=None, + gather_output=False, + ): + super().__init__( + input_size=config.hidden_size, + output_size=config.hidden_size, + config=config, + init_method=config.init_method, + bias=False, + gather_output=gather_output, + skip_bias_add=False, + tp_comm_buffer_name=tp_comm_buffer_name, + tp_group=tp_group, + ) + + def forward(self, x, *args, **kwargs): + return super().forward(x) + + +class WrappedTELinear(TELinear): + # TODO: docstring + def __init__( + self, + config, + layer_number=None, # unused + model_comm_pgs=None, # unused + cp_comm_type=None, # unused + tp_group=None, # unused + tp_comm_buffer_name=None, + gather_output=False, # unused + ): + super().__init__( + input_size=config.hidden_size, + output_size=config.hidden_size, + parallel_mode="duplicated", + # parallel_mode=None, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + tp_comm_buffer_name=tp_comm_buffer_name, + is_expert=False, + ) + + def forward(self, x, *args, **kwargs): + return super().forward(x) + + +class WrappedMambaMixer(MambaMixer): + def __init__(self, *args, cp_comm_type: Optional[str] = None, **kwargs): + # ignoring cp_comm_type + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ) -> Tuple[Tensor, Tensor]: + result = super().forward(hidden_states, inference_context=inference_context) + # Ensure we return a tuple of two tensors + assert isinstance(result, tuple) and len(result) == 2 + return result + + +# NOTE: new method +def get_layer_spec_for_layer( + block_params: PuzzletronTransformerBlockConfig, + config: PuzzletronHeterogeneousTransformerConfig, + use_transformer_engine: bool, + normalization: Optional[str] = None, + qk_l2_norm: Optional[bool] = False, +) -> ModuleSpec: + # this part is copied from megatron.core.models.gpt.gpt_layer_specs.get_gpt_decoder_block_spec() + if use_transformer_engine: + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=block_params.mlp.num_moe_experts, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=block_params.attention.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + qk_l2_norm=qk_l2_norm, + use_kitchen=config.use_kitchen, + # use_te_activation_func=config.use_te_activation_func, #TODO: part of megatron 0.14 version. check if this is needed now. + ) + else: + layer_spec = get_gpt_layer_local_spec( + num_experts=block_params.mlp.num_moe_experts, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=block_params.attention.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + normalization=normalization, + qk_l2_norm=qk_l2_norm, + use_kitchen=config.use_kitchen, + ) + if block_params.attention.no_op: + layer_spec.submodules.input_layernorm = IdentityOp + layer_spec.submodules.self_attn_bda = IdentityFuncOp + layer_spec.submodules.self_attention = ModuleSpec(module=IdentityOp) + elif block_params.attention.replace_with_linear: + layer_spec.submodules.self_attention = ModuleSpec( + module=WrappedTENormLinear if use_transformer_engine else WrappedLinear, + params={"tp_comm_buffer_name": "linear_attn"}, + ) + elif block_params.attention.is_mamba: + mamba_mixer_params = dict( + d_model=config.hidden_size, + d_conv=config.d_conv, + expand=config.expand, + D_has_hdim=config.D_has_hdim, + rmsnorm=config.rmsnorm, + norm_before_gate=config.norm_before_gate, + dt_min=config.dt_min, + dt_max=config.dt_max, + dt_scale=config.dt_scale, + bias=config.bias, + conv_bias=config.conv_bias, + chunk_size=config.chunk_size, + ) + layer_spec.submodules.self_attention = ModuleSpec( + module=WrappedMambaMixer, + params=mamba_mixer_params, + submodules=MambaMixerSubmodules( + in_proj=( + TELayerNormColumnParallelLinear + if use_transformer_engine + else ColumnParallelLinear + ), + out_proj=TERowParallelLinear if use_transformer_engine else RowParallelLinear, + ), + ) + + if block_params.mlp.no_op: + layer_spec.submodules.pre_mlp_layernorm = IdentityOp + layer_spec.submodules.mlp_bda = IdentityFuncOp + layer_spec.submodules.mlp = ModuleSpec(module=IdentityOp) + elif block_params.mlp.replace_with_linear: + layer_spec.submodules.mlp = ModuleSpec( + module=WrappedTENormLinear if use_transformer_engine else WrappedLinear, + params={"tp_comm_buffer_name": "linear_mlp"}, + ) + + layer_spec.submodules.sharded_state_dict_keys_map = _get_sharded_state_dict_keys_map( + block_params, use_transformer_engine + ) + return layer_spec + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py#L168 +def get_gpt_heterogeneous_layer_spec_puzzletron( + config: PuzzletronHeterogeneousTransformerConfig, + use_transformer_engine: bool, + normalization: Optional[str] = None, + qk_l2_norm: Optional[bool] = False, + vp_stage: Optional[int] = None, +) -> TransformerBlockSubmodules: + """Generate heterogeneous layer specifications for Puzzletron transformer models. + + This function is a specialized version of the original Megatron Core + `get_gpt_heterogeneous_layer_spec` function, adapted for Puzzletron's specific + heterogeneous transformer architecture requirements. + + Key differences from the original: + - **Signature**: Adds `normalization` and `qk_l2_norm` parameters, removes `pp_rank` + - **Architecture**: Uses `get_layer_spec_for_layer()` helper for modular layer creation + - **Pipeline Parallel**: Enhanced with `pipeline_model_parallel_layout` support + - **Configuration**: Uses `PuzzletronHeterogeneousTransformerConfig` with Mamba parameters + - **Layer Norm**: Simplified to `TENorm` vs `LNImpl` (removes `WrappedTorchNorm` complexity) + - **Features**: Supports Mamba layers, custom attention types, and advanced parallelization + + Args: + config: Puzzletron heterogeneous transformer configuration + use_transformer_engine: Whether to use Transformer Engine optimizations + normalization: Optional normalization type override + qk_l2_norm: Whether to apply L2 normalization to QK matrices + vp_stage: Virtual pipeline stage for advanced parallelization + + Returns: + TransformerBlockSubmodules: Complete layer specification for the heterogeneous model + """ + # Create the layer specs for the model. + layer_specs = [ + get_layer_spec_for_layer( + block_params, config, use_transformer_engine, normalization, qk_l2_norm + ) + for block_params in config.per_block_parameters + ] + + # Slice the layer specs to only include the layers that are built in this pipeline stage. + # Note: MCore layer_number starts at 1 + num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) + + if config.pipeline_model_parallel_layout is not None: + local_layer_specs = [ + layer_specs[layer_id] + for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list( + layer_type=LayerType.decoder, vp_stage=vp_stage + ) + ] + else: + offset = get_transformer_layer_offset(config, vp_stage=vp_stage) + local_layer_specs = layer_specs[offset : offset + num_layers_to_build] + + if use_transformer_engine: + layer_norm_impl = TENorm + else: + layer_norm_impl = LNImpl + + # Block spec. + block_spec = TransformerBlockSubmodules( + layer_specs=local_layer_specs, layer_norm=layer_norm_impl + ) + + return block_spec + + +# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/gpt_builders.py#L23 +def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): + """Build a GPT model with Puzzletron's heterogeneous transformer architecture. + + This function is a specialized version of the original Megatron-LM `gpt_builder` function, + adapted for Puzzletron's heterogeneous transformer architecture requirements. + + Key differences from the original: + - **Simplified**: Focuses exclusively on heterogeneous models (rejects legacy, spec-based, MoE, MTP) + - **Configuration**: Only supports args-based config (removes YAML complexity) + - **Layer Spec**: Uses single `get_gpt_heterogeneous_layer_spec_puzzletron` function + - **Error Handling**: Explicit error messages for unsupported features + - **Logging**: Removes debug logging for cleaner implementation + + Args: + args: Command-line arguments namespace containing model configuration parameters + pre_process: Whether to include pre-processing layers + post_process: Whether to include post-processing layers + vp_stage: Virtual pipeline stage for advanced parallelization + config: Optional pre-configured transformer config (if None, created from args) + + Returns: + GPTModel: Configured GPT model with heterogeneous transformer architecture + + Raises: + ValueError: If legacy models, spec-based models, or MTP are requested + """ + assert config is not None, "config is required" + if args.use_legacy_models: + raise ValueError("Legacy models are not supported") + if args.spec is not None: + raise ValueError("Spec is not supported") + use_te = args.transformer_impl == "transformer_engine" + transformer_layer_spec = get_gpt_heterogeneous_layer_spec_puzzletron( + config, + use_te, + normalization=args.normalization, + qk_l2_norm=args.qk_l2_norm, + vp_stage=vp_stage, + ) + mtp_block_spec = None + if args.mtp_num_layers is not None: + raise ValueError("MTP is not supported") + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + ) + + return model From 8c84fee21b3f368b6dd470b3ab1d818d76801e34 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 3 Mar 2026 07:52:42 -0800 Subject: [PATCH 32/62] [CI] Update to only run puzzletron tests Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/example_tests.yml | 83 ++----------------------- .github/workflows/gpu_tests.yml | 18 +++--- pyproject.toml | 2 +- tests/examples/puzzletron/test_dummy.py | 18 ++++++ tox.ini | 8 +++ 5 files changed, 40 insertions(+), 89 deletions(-) create mode 100644 tests/examples/puzzletron/test_dummy.py diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index f5844083ae..ab9c88346c 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -56,38 +56,6 @@ jobs: match_pattern: "^DCO$|^linux$" # Wait for DCO and Unit tests / linux to pass delay: 300s - ##### PyTorch Example Tests (speculative_decoding requires 26.01 image) ##### - torch-pr: - needs: [check-file-changes, wait-checks] - if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' - strategy: &torch_strategy - fail-fast: false - matrix: - example: [llm_distill, llm_qat, llm_sparsity] - include: - - example: speculative_decoding - docker_image: "26.01" - uses: ./.github/workflows/_example_tests_runner.yml - secrets: inherit - with: - docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3" - example: ${{ matrix.example }} - timeout_minutes: 30 - pip_install_extras: "[hf,dev-test]" - runner: linux-amd64-gpu-h100-latest-1 - - torch-non-pr: - if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} - strategy: *torch_strategy - uses: ./.github/workflows/_example_tests_runner.yml - secrets: inherit - with: - docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3" - example: ${{ matrix.example }} - timeout_minutes: 30 - pip_install_extras: "[hf,dev-test]" - runner: linux-amd64-gpu-rtxpro6000-latest-2 - ##### TensorRT-LLM Example Tests ##### trtllm-pr: needs: [check-file-changes, wait-checks] @@ -95,69 +63,26 @@ jobs: strategy: fail-fast: false matrix: - example: [llm_ptq, vlm_ptq] + example: [puzzletron] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc5" example: ${{ matrix.example }} - pip_install_extras: "[hf,dev-test]" - runner: linux-amd64-gpu-rtxpro6000-latest-1 - - trtllm-non-pr: - if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} - strategy: - fail-fast: false - matrix: - example: [llm_autodeploy, llm_eval, llm_ptq, vlm_ptq] - uses: ./.github/workflows/_example_tests_runner.yml - secrets: inherit - with: - docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc5" - example: ${{ matrix.example }} - pip_install_extras: "[hf,dev-test]" - runner: linux-amd64-gpu-rtxpro6000-latest-2 - - ##### ONNX/TensorRT Example Tests ##### - onnx-pr: - needs: [check-file-changes, wait-checks] - if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' - strategy: &onnx_strategy - fail-fast: false - matrix: - example: [diffusers, torch_onnx] - uses: ./.github/workflows/_example_tests_runner.yml - secrets: inherit - with: - docker_image: "nvcr.io/nvidia/tensorrt:26.01-py3" - example: ${{ matrix.example }} - pip_install_extras: "[all,dev-test]" - runner: linux-amd64-gpu-l4-latest-1 - - onnx-non-pr: - if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} - strategy: *onnx_strategy - uses: ./.github/workflows/_example_tests_runner.yml - secrets: inherit - with: - docker_image: "nvcr.io/nvidia/tensorrt:26.01-py3" - example: ${{ matrix.example }} - pip_install_extras: "[all,dev-test]" + pip_install_extras: "[hf,puzzletron,dev-test]" runner: linux-amd64-gpu-rtxpro6000-latest-2 ##### Required Check for PR ##### example-pr-required-check: # Run even if example tests are skipped if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} - needs: [check-file-changes, torch-pr, trtllm-pr, onnx-pr] + needs: [check-file-changes, trtllm-pr] runs-on: ubuntu-latest steps: - name: Required GPU tests did not succeed if: | needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && ( - needs.torch-pr.result != 'success' || - needs.trtllm-pr.result != 'success' || - needs.onnx-pr.result != 'success' + needs.trtllm-pr.result != 'success' )) run: exit 1 diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index 059da06c2d..843bc4d932 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -63,16 +63,16 @@ jobs: fail-fast: false matrix: include: - - example: gpu - timeout: 60 - container_image: pytorch:26.01-py3 - - example: gpu-megatron - timeout: 90 - container_image: pytorch:26.01-py3 - - example: gpu-trtllm + - example: gpu-puzzletron timeout: 30 - container_image: tensorrt-llm/release:1.3.0rc5 - runs-on: linux-amd64-gpu-rtxpro6000-latest-1 + container_image: pytorch:26.01-py3 + # - example: gpu-megatron + # timeout: 90 + # container_image: pytorch:26.01-py3 + # - example: gpu-trtllm + # timeout: 30 + # container_image: tensorrt-llm/release:1.3.0rc5 + runs-on: linux-amd64-gpu-rtxpro6000-latest-2 timeout-minutes: ${{ matrix.timeout }} container: &gpu_container image: nvcr.io/nvidia/${{ matrix.container_image }} diff --git a/pyproject.toml b/pyproject.toml index 9bf3333b35..dea22b280a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ dev-test = [ "tox-current-env>=0.0.12", ] # Compound extras via self-references -all = ["nvidia-modelopt[onnx,hf]"] +all = ["nvidia-modelopt[onnx,hf,puzzletron]"] dev = ["nvidia-modelopt[all,dev-docs,dev-lint,dev-test]"] [project.urls] diff --git a/tests/examples/puzzletron/test_dummy.py b/tests/examples/puzzletron/test_dummy.py new file mode 100644 index 0000000000..d07694471a --- /dev/null +++ b/tests/examples/puzzletron/test_dummy.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_dummy(): + assert True diff --git a/tox.ini b/tox.ini index 28467a7159..e72c9f7ba8 100644 --- a/tox.ini +++ b/tox.ini @@ -70,6 +70,14 @@ commands = # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" python -m pytest tests/gpu +[testenv:cuda13-gpu-puzzletron] +commands_pre = + # Install deps here so that it gets installed even in --current-env + pip install -e .[hf,puzzletron,dev-test] +commands = + # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" + python -m pytest tests/gpu/torch/puzzletron + [testenv:cuda13-gpu-trtllm] commands_pre = # Install deps here so that it gets installed even in --current-env From 5f77c811f8ac91eadc4127ff3aa652afa6a11e4b Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:24:46 +0530 Subject: [PATCH 33/62] Pin torchprofile==0.0.4 to fix CI Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 37400d92ca..d32a1e18ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ dev-test = [ "pytest-instafail", "pytest-timeout", "timm", - "torchprofile>=0.0.4", # optional dependency for modelopt.torch + "torchprofile==0.0.4", # optional dependency for modelopt.torch "torchvision", "torch-geometric", "tox>4.18", From 82df595a34a7606914e1395687f74c701c5324c5 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 16:03:54 +0100 Subject: [PATCH 34/62] Add anymodel-core to feature/puzzletron (#974) ### What does this PR do? - Add converter, model_descriptor, puzzformer, and llama model support ## Summary by CodeRabbit * **New Features** * CLI to convert HF models to AnyModel format; unified package exports; Llama support; pruning toolkit (FFN, KV-heads, MoE expert removal) with multiple init strategies and runtime hooks; per-layer patching and no-op primitives; improved checkpoint export flow. * **Documentation** * Comprehensive AnyModel README and SPDX/license headers added. * **Tests** * Expanded parameterized end-to-end tests and new tokenizer/model config resources. * **Chores** * Package initializers consolidated public API; lightweight dummy modules for testing. --------- Signed-off-by: Daniel Korzekwa --- .../nas/plugins/megatron_hooks/base_hooks.py | 380 +++++++++- modelopt/torch/puzzletron/anymodel/README.md | 204 +++++ .../torch/puzzletron/anymodel/__init__.py | 64 ++ .../puzzletron/anymodel/converter/__init__.py | 19 + .../anymodel/converter/convert_any_model.py | 68 ++ .../anymodel/converter/converter.py | 235 ++++++ .../anymodel/converter/converter_factory.py | 75 ++ .../anymodel/model_descriptor/__init__.py | 18 + .../model_descriptor/model_descriptor.py | 216 ++++++ .../model_descriptor_factory.py | 122 +++ .../puzzletron/anymodel/models/__init__.py | 24 + .../anymodel/models/llama/__init__.py | 19 + .../anymodel/models/llama/llama_converter.py | 53 ++ .../models/llama/llama_model_descriptor.py | 131 ++++ .../anymodel/puzzformer/__init__.py | 30 + .../puzzletron/anymodel/puzzformer/no_op.py | 79 ++ .../puzzletron/anymodel/puzzformer/utils.py | 122 +++ .../decilm/deci_lm_hf_code/block_config.py | 97 +-- .../nas/plugins/puzzletron_nas_plugin.py | 38 +- .../pruning/expert_removal_pruning_mixin.py | 239 ++++++ .../pruning/ffn_intermediate_pruning_mixin.py | 102 +++ .../pruning/kv_heads_pruning_mixin.py | 127 ++++ .../torch/puzzletron/pruning/pruning_ckpts.py | 94 +-- .../torch/puzzletron/pruning/pruning_mixin.py | 73 ++ .../torch/puzzletron/pruning/pruning_utils.py | 652 ++++++++++++++++ .../tools/bypassed_training/child_init.py | 704 ++++-------------- .../puzzletron/tools/checkpoint_utils_hf.py | 171 +++-- .../torch/puzzletron/utils/dummy_modules.py | 75 ++ tests/_test_utils/torch/puzzletron/utils.py | 143 +++- ..._convert_llama3_config_to_decilm_config.py | 50 -- .../nas/plugins/test_nas_convert.py | 19 +- .../puzzletron/nas/plugins/test_nas_search.py | 10 +- .../llama_3_1_8b_instruct-attn-pruning.yaml | 107 +++ .../llama_3_1_8b_instruct.yaml | 107 +++ .../pruning/attn_pruning.yaml | 16 + .../pruning/ffn_pruning.yaml | 18 + .../pruning/hidden_dim_pruning.yaml | 15 + .../pruning/pruning_defaults.yaml | 33 + .../validate_model_defaults.yaml | 15 + .../validate_solutions_defaults.yaml | 10 + .../llama_3_1_8b_instruct/config.json | 38 + .../tokenizer/special_tokens_map.json | 16 + .../resources/tokenizer/tokenizer.json | 212 ++++++ .../resources/tokenizer/tokenizer_config.json | 13 + .../resources/tokenizer/truncate_tokenizer.py | 62 ++ tests/gpu/torch/puzzletron/test_puzzletron.py | 303 ++++++-- 46 files changed, 4504 insertions(+), 914 deletions(-) create mode 100644 modelopt/torch/puzzletron/anymodel/README.md create mode 100644 modelopt/torch/puzzletron/anymodel/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/converter/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py create mode 100644 modelopt/torch/puzzletron/anymodel/converter/converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/converter/converter_factory.py create mode 100644 modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/llama/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py create mode 100644 modelopt/torch/puzzletron/anymodel/puzzformer/utils.py create mode 100644 modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/pruning_utils.py create mode 100644 modelopt/torch/puzzletron/utils/dummy_modules.py delete mode 100644 tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index 56436acfdd..7cd7214443 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# mypy: ignore-errors """Forward hooks for activation-based importance estimation.""" import gc @@ -26,6 +27,7 @@ from torch import nn import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 from modelopt.torch.puzzletron.tools.logger import aprint from modelopt.torch.puzzletron.tools.robust_json import json_dump @@ -150,7 +152,8 @@ def dump_activations_logs( torch.save(activations_log, activations_log_path) if rank == 0: - args.activation_hooks_kwargs.pop("model") + if args.activation_hooks_kwargs is not None: + args.activation_hooks_kwargs.pop("model", None) json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") dist.barrier() @@ -822,3 +825,378 @@ def _save_channel_importance_results( aprint(f"Score range: {avg_scores.min():.4f} to {avg_scores.max():.4f}") aprint(f"Score mean: {avg_scores.mean():.4f}") aprint(f"Score std: {avg_scores.std():.4f}") + + +class RemoveExpertsIndependentHook(ForwardHook, ABC): + """Base hook for measuring expert importance in Mixture-of-Experts models. + + This hook measures how much removing each expert affects the model output + by comparing outputs with and without each expert. + """ + + def __init__(self, moe: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + moe: The MoE module to analyze + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.moe = moe + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.num_local_experts = block_config.ffn.moe.num_local_experts + self.num_experts_per_tok = block_config.ffn.moe.num_experts_per_tok + # tensor of zeros of size num experts + self.diffs = ["mse", "cosine"] + some_param = next(self.moe.parameters()) + self.diffs = { + k: torch.zeros( + size=(self.num_local_experts,), dtype=torch.float32, device=some_param.device + ) + for k in self.diffs + } + self.call_count = 0 + + @abstractmethod + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for measuring expert importance. + + This method is called twice per forward pass: + 1. First call (router_logits=None): Compute original routing and expert outputs + 2. Second call (router_logits provided): Re-run with modified logits (expert disabled) + + Args: + hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) + router_logits: Optional pre-computed router logits. If None, compute from hidden_states. + + Returns: + tuple of (router_logits, routed_experts): + - router_logits: Shape (num_tokens, num_local_experts) + - routed_experts: Shape (num_tokens, hidden_dim) + """ + raise NotImplementedError + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that measures expert importance.""" + hidden_states = args[0] + router_logits, original_routed_out = self.get_router_logits_and_routed_experts( + hidden_states + ) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + original_routed_out = original_routed_out.view(-1, original_routed_out.shape[-1]) + + _, router_indices = torch.topk(router_logits, self.num_experts_per_tok, dim=-1) + self.call_count += 1 + + for i_expert in range(self.num_local_experts): + expert_mask = router_indices == i_expert + is_token_routed_to_this_expert = expert_mask.any(dim=-1) + + num_tokens_displaced = is_token_routed_to_this_expert.sum() + if num_tokens_displaced == 0: + continue + num_total_tokens = is_token_routed_to_this_expert.numel() + + relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] + + router_logits_without_i = router_logits.clone() + router_logits_without_i[..., i_expert] = -float("inf") # disable expert i + router_logits_without_i = router_logits_without_i[is_token_routed_to_this_expert, :] + _, routed_out_without_i = self.get_router_logits_and_routed_experts( + relevant_hidden_states, router_logits_without_i + ) + + relevant_tokens_original_out = original_routed_out[is_token_routed_to_this_expert, :] + self.diffs["mse"][i_expert] += ( + nn.functional.mse_loss( + relevant_tokens_original_out, routed_out_without_i, reduction="mean" + ) + * num_tokens_displaced + / num_total_tokens + ) + self.diffs["cosine"][i_expert] += ( + -nn.functional.cosine_similarity( + relevant_tokens_original_out, routed_out_without_i, dim=-1 + ).mean() + * num_tokens_displaced + / num_total_tokens + ) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format.""" + expert_ranks_mse = torch.argsort(self.diffs["mse"]) + expert_ranks_cosine = torch.argsort(self.diffs["cosine"]) + return { + "expert_ranks_mse": expert_ranks_mse.cpu(), + "expert_ranks_cosine": expert_ranks_cosine.cpu(), + "cosine_diffs": (self.diffs["cosine"] / self.call_count).cpu(), + "mse_diffs": (self.diffs["mse"] / self.call_count).cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert importance scores.""" + return self.diffs["mse"] + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "diffs_mse": self.diffs["mse"].cpu(), + "diffs_cosine": self.diffs["cosine"].cpu(), + "call_count": self.call_count, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.diffs["mse"] = state_dict["diffs_mse"].to(self.diffs["mse"].device) + self.diffs["cosine"] = state_dict["diffs_cosine"].to(self.diffs["cosine"].device) + self.call_count = state_dict["call_count"] + + +class NemotronHRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for NemotronH models.""" + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for NemotronH MoE. + + Based on NemotronHMOE forward, uses minimum ops to get router_logits and routed_experts. + """ + orig_shape = hidden_states.shape + # NemotronHMOE.gate forward, copied to extract router_logits + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + if router_logits is None: + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), self.moe.gate.weight.type(torch.float32) + ) + router_logits = router_logits.sigmoid() + router_logits = router_logits + self.moe.gate.e_score_correction_bias.unsqueeze(0) + + topk_indices = self._get_topk_indices_without_correction_bias(router_logits) + topk_weights = router_logits.gather(1, topk_indices) + if self.moe.gate.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.moe.gate.routed_scaling_factor + # Routed experts forward + hidden_states = self.moe.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + return router_logits, hidden_states + + @torch.no_grad() + def _get_topk_indices_without_correction_bias(self, scores: torch.Tensor) -> torch.Tensor: + """Get topk indices without correction bias. + + Same as NemotronHMOE.gate.get_topk_indices but without adding e_score_correction_bias. + """ + group_scores = ( + scores.view( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.moe.gate.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .reshape(-1, self.moe.gate.n_routed_experts) + ) + scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.moe.gate.top_k, dim=-1, sorted=False)[1] + return topk_indices + + +class RankedChoiceVotingHook(ForwardHook): + """Hook for ranking experts using ranked choice voting algorithm. + + This hook tracks router decisions and uses ranked choice voting to determine + which experts are least important (can be pruned first). + """ + + def __init__(self, router: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + router: The router module (typically nn.Linear) + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.router_argsort: list[torch.Tensor] = [] + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.top_k = block_config.ffn.moe.num_experts_per_tok + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that records router decisions. + + Args: + module: The router module + args: Tuple with one tensor entry (B, T, I) + output: Router logits of shape (B, T, E) + """ + router_logits = output[0] if isinstance(output, tuple) else output + num_experts = router_logits.shape[-1] + router_argsort = torch.argsort(router_logits, dim=-1, descending=True) + router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() + self.router_argsort.append(router_argsort) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format using ranked choice voting.""" + router_argsort = torch.concat(self.router_argsort, dim=0) + num_tokens, num_experts = router_argsort.shape + + expert_ranks = torch.full((num_experts,), -1) + expert_counts_at_pruning_time = {} + + expert_kept_per_iteration: list[list[int]] = [] + expert_counts_per_iteration: list[dict[int, int]] = [] + + for rank in range(num_experts): + ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True) + ids = ids.tolist() + counts = counts.tolist() + expert_counts = dict(zip(ids, counts)) + + expert_kept_per_iteration.append(ids) + expert_counts_per_iteration.append(expert_counts) + + least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1]) + + expert_ranks[least_popular_expert] = rank + expert_counts_at_pruning_time[least_popular_expert] = min_count + aprint(f"#{rank}: router_argsort shape = {router_argsort.shape}") + router_argsort = router_argsort[router_argsort != least_popular_expert].view( + num_tokens, -1 + ) + + zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long) + for expert_id, expert_counts_val in expert_counts_per_iteration[0].items(): + zero_shot_expert_counts[expert_id] = expert_counts_val + + # Compute zero-shot expert ranks (double argsort converts counts to rank positions) + zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts)) + + aprint("Done: Returning hook metadata.") + return { + "expert_ranks": expert_ranks, + "zero_shot_expert_ranks": zero_shot_expert_ranks, + "expert_counts_at_pruning_time": expert_counts_at_pruning_time, + "expert_counts_per_iteration": expert_counts_per_iteration, + "top_k": self.top_k, + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert ranks.""" + if not self.router_argsort: + return torch.tensor([]) + router_argsort = torch.concat(self.router_argsort, dim=0) + return router_argsort[:, 0].float() + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort], + "top_k": self.top_k, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]] + self.top_k = state_dict["top_k"] + + def get_progress_info(self) -> dict: + """Get progress information.""" + return { + "num_batches_processed": len(self.router_argsort), + "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort) + if self.router_argsort + else 0, + } + + +class RankedChoiceVotingHookNemotronH(RankedChoiceVotingHook): + """Ranked choice voting hook for NemotronH models. + + In NemotronH, router_logits is an internal temporary state that never leaves + the forward() function. We reconstruct router_logits from the input hidden_states. + """ + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that reconstructs router logits from hidden states.""" + hidden_states = args[0] + hidden_states = hidden_states.view(-1, module.config.hidden_size) + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), module.weight.type(torch.float32) + ) + super().__call__(module, args, router_logits) + + +class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for Qwen3-VL models. + + TODO: Implement get_router_logits_and_routed_experts based on Qwen3-VL MoE forward pass. + """ + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for Qwen3-VL MoE. + + Note: This is a placeholder implementation. Implement based on Qwen3VLMoeSparseMoe forward. + """ + batch_size = ( + hidden_states.shape[0] * hidden_states.shape[1] + if hidden_states.ndim > 2 + else hidden_states.shape[0] + ) + router_logits_out = torch.zeros( + batch_size, self.num_local_experts, device=hidden_states.device + ) + routed_experts = hidden_states.view(-1, hidden_states.shape[-1]) + return router_logits_out, routed_experts + + +class GptOssRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for GPT-OSS models. + + TODO: Implement get_router_logits_and_routed_experts based on GPT-OSS MoE forward pass. + This is a placeholder implementation that allows the framework to run. + """ + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for GPT-OSS MoE. + + Note: This is a placeholder implementation. For proper expert scoring, + implement based on GptOssSparseMoeBlock forward pass. + + Args: + hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) + router_logits: Optional pre-computed router logits + + Returns: + tuple of (router_logits, routed_experts): + - router_logits: Shape (num_tokens, num_local_experts) - zeros as placeholder + - routed_experts: Original hidden states (no-op) + """ + batch_size = ( + hidden_states.shape[0] * hidden_states.shape[1] + if hidden_states.ndim > 2 + else hidden_states.shape[0] + ) + router_logits_out = torch.zeros( + batch_size, self.num_local_experts, device=hidden_states.device + ) + routed_experts = hidden_states.view(-1, hidden_states.shape[-1]) + return router_logits_out, routed_experts diff --git a/modelopt/torch/puzzletron/anymodel/README.md b/modelopt/torch/puzzletron/anymodel/README.md new file mode 100644 index 0000000000..9dea9d45f9 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/README.md @@ -0,0 +1,204 @@ +# AnyModel Guide + +This guide explains how to add support for new models in the Puzzletron pipeline. + +## Convert model + +Convert a HuggingFace model to Puzzletron format. + +Step 1: Create Model Descriptor + +Extend `ModelDescriptor` and implement `layer_name_predicates()` to define regex patterns for grouping weights into subblocks (embeddings, lm_head, block_N_ffn, block_N_attention). + +Key points: + +- Find weight names on the model's HuggingFace page → click "Files info" to see the safetensors structure with all tensor names (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)) + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +Step 2: Create Converter + +Extend `Converter` and implement `create_block_configs_from_main_config()` to create per-layer BlockConfigs from the HuggingFace config. + +Key points: + +- Import correct HuggingFace config class (e.g., `MistralConfig`, `LlamaConfig`, `Qwen2Config`). Find it in the transformers source: `github.com/huggingface/transformers/tree/main/src/transformers/models//configuration_.py` + +See example: [llama_converter.py](models/llama/llama_converter.py) + +Step 3: Create `models//__init__.py` + +Export descriptor and converter classes: + +```python +from models.._model_descriptor import MyModelDescriptor +from models.._converter import MyConverter +``` + +Step 4: Register in `models/__init__.py` + +Add import to trigger factory registration: + +```python +from models. import * +``` + +## Usage + +```python +from modelopt.torch.puzzletron.anymodel import convert_model + +convert_model( + input_dir="path/to/hf_checkpoint", + output_dir="path/to/puzzletron_checkpoint", + converter="model_name", +) +``` + +## Compress model + +Run pruning and compression on a Puzzletron model. + +Step 1: Implement ModelDescriptor methods for compression + +Add to your `ModelDescriptor`: + +- `decoder_layer_cls()` - return the decoder layer class(es) to patch for heterogeneous config support +- `block_config_to_layer_overrides()` - map BlockConfig to layer override dict (see [details](#implementing-block_config_to_layer_overrides)) +- `init_rotary_embedding()` - reinitialize rotary embeddings after model loading (see [details](#implementing-init_rotary_embedding)) +- `input_embedding_name()` - return the name of the input embedding layer (see [details](#implementing-path-based-methods)) +- `output_embedding_name()` - return the name of the output embedding layer (see [details](#implementing-path-based-methods)) +- `layer_block_name()` - return the name pattern for decoder layers (see [details](#implementing-path-based-methods)) +- `final_norm_name()` - return the name of the final normalization layer (see [details](#implementing-path-based-methods)) +- `attn_no_op_post_init()` - replace attention sublayers with no-op modules +- `mlp_no_op_post_init()` - replace MLP sublayers with no-op modules + +Step 2: Create FFN Layer Descriptor + +Extend `FFNIntermediateLayerDescriptor` to define model-specific paths for FFN pruning hooks (`down_proj_name`, `ffn_prefix_name`, `linear_weight_names`). Derive values from your model's weight names in `layer_name_predicates()`. + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) → `LlamaFFNIntermediateLayerDescriptor` + +Step 3: Configure YAML files + +Update the main model config YAML: + +- Set `descriptor` to match the name used in `@ModelDescriptorFactory.register_decorator("your_model_name")` +- See example: [llama_3_1_8b_instruct.yaml](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml) + +Update pruning YAML files (`ffn_pruning.yaml`, `expert_pruning.yaml`, etc.): + +- Set `pruning_mixin._target_` to the appropriate mixin class +- Set `layer_descriptor._target_` to your layer descriptor class +- Set `hook_class` to the activation hook for scoring +- Set `target_layer` in `activation_hooks_kwargs` to the layer name for hook attachment +- See examples in [configs/llama_3_1_8b_instruct/pruning/](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/) + +## End-to-end example + +See [test_puzzletron.py](../../../../tests/gpu/torch/puzzletron/test_puzzletron.py) for a complete example that runs both convert and compression steps. + +--- + +## Advanced Topics + +## Pruning Configuration + +### Pruning YAML Structure + +Each pruning type has a YAML config with these key fields: + +```yaml +pruning_mixin: + _target_: pruning._pruning_mixin. + layer_descriptor: + _target_: models.. + +hook_class: ${get_object:utils.activation_hooks.hooks.} +activation_hooks_kwargs: + method: + target_layer: "" # e.g., "mlp.down_proj", "self_attn.o_proj" +``` + +| Field | Description | +|-------|-------------| +| `pruning_mixin._target_` | Mixin class that orchestrates this pruning type | +| `layer_descriptor._target_` | Model-specific class defining layer paths for hooks | +| `hook_class` | Activation hook class for importance scoring | +| `target_layer` | Layer name (relative to decoder block) where hooks attach | + +### Adding a New Hook Class + +1. **Implement the hook** in `modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`: + - Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook`) + - Implement required methods (e.g., `get_router_logits_and_routed_experts`) + +2. **Register the hook** in the appropriate pruning mixin's `supported_hooks()`: + + For FFN pruning (`pruning/ffn_intermediate_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [IndependentChannelContributionHook, IterativeChannelContributionHook, YourNewHook] + ``` + + For expert removal (`pruning/expert_removal_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [RankedChoiceVotingHook, ..., YourNewHook] + ``` + +3. **Reference in YAML**: + + ```yaml + hook_class: ${get_object:utils.activation_hooks.hooks.YourNewHook} + ``` + +### Pruning Types Reference + +| Type | Mixin | Example Hooks | +|------|-------|---------------| +| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | +| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | +| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | + +## Implementing `block_config_to_layer_overrides` + +Maps Puzzletron's [`BlockConfig`](../decilm/deci_lm_hf_code/block_config.py) fields to HuggingFace config attribute names. Only override attributes that change during pruning: + +| BlockConfig Field | HuggingFace Attribute (check `config.json`) | +|-------------------|---------------------------------------------| +| `attention.num_key_value_heads` | `num_key_value_heads` | +| `ffn.intermediate_size` | `intermediate_size` | +| `ffn.moe.num_local_experts` | `num_experts` or `n_routed_experts` (model-specific) | +| `ffn.moe.expert_intermediate_dim` | `moe_intermediate_size` | + +**Tip**: Check the model's `config.json` for exact attribute names - they vary between models. + +See examples: [qwen3_vl](models/qwen3_vl/qwen3_vl_model_descriptor.py), [nemotron_h](models/nemotron_h/nemotron_h_model_descriptor.py) + +--- + +## Implementing path-based methods + +These methods return paths derived from the model's weight names: + +- `input_embedding_name()`, `output_embedding_name()`, `layer_block_name()`, `final_norm_name()` + +Find them on the model's HuggingFace page → "Files info" → safetensors structure (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)). + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +--- + +## Implementing `init_rotary_embedding` + +Rotary embeddings are computed modules (not saved weights). After model sharding, they need re-initialization on the correct device/dtype. + +Look in `github.com/huggingface/transformers/tree/main/src/transformers/models//modeling_.py` for: + +- `class.*Rotary` — the rotary embedding class name and constructor arguments +- `self.rotary_emb` — the attribute path + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) diff --git a/modelopt/torch/puzzletron/anymodel/__init__.py b/modelopt/torch/puzzletron/anymodel/__init__.py new file mode 100644 index 0000000000..e1755a16d8 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/__init__.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""AnyModel: Architecture-agnostic model compression for HuggingFace models. + +This module provides a declarative approach to model compression that works with +any HuggingFace model without requiring custom modeling code. Instead of duplicating +HuggingFace modeling classes, AnyModel uses ModelDescriptors that define: + +1. Which decoder layer class(es) to patch for heterogeneous configs +2. How to map BlockConfig to layer-specific overrides +3. Weight name patterns for subblock checkpointing + +Example usage: + >>> from modelopt.torch.puzzletron.anymodel import convert_model + >>> convert_model( + ... input_dir="path/to/hf_checkpoint", + ... output_dir="path/to/anymodel_checkpoint", + ... converter="llama", + ... ) + +Supported models: + - llama: Llama 2, Llama 3, Llama 3.1, Llama 3.2 + - (more to come: qwen2, mistral_small, etc.) +""" + +# Import models to trigger factory registration +from modelopt.torch.puzzletron.anymodel import models # noqa: F401 +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory, convert_model +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import ( + MatchingZeros, + Same, + deci_x_patcher, + return_tuple_of_size, +) + +__all__ = [ + "Converter", + "ConverterFactory", + "ModelDescriptor", + "ModelDescriptorFactory", + "deci_x_patcher", + "MatchingZeros", + "Same", + "return_tuple_of_size", + "convert_model", +] diff --git a/modelopt/torch/puzzletron/anymodel/converter/__init__.py b/modelopt/torch/puzzletron/anymodel/converter/__init__.py new file mode 100644 index 0000000000..02903b817d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Converters for transforming HuggingFace models to AnyModel format.""" + +from .convert_any_model import * +from .converter import * +from .converter_factory import * diff --git a/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py new file mode 100644 index 0000000000..889685c001 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Convert a HuggingFace model to AnyModel format.""" + +from pathlib import Path + +from modelopt.torch.puzzletron.anymodel.converter.converter import Converter +from modelopt.torch.puzzletron.anymodel.converter.converter_factory import ConverterFactory +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory + +__all__ = ["convert_model"] + + +def convert_model( + input_dir: str, + output_dir: str, + converter: Converter | str, +): + """Convert a HuggingFace model to AnyModel format. + + This function converts a HuggingFace checkpoint to the AnyModel format used + for compression. The conversion process: + + 1. Copies non-weight files (config, tokenizer, etc.) + 2. Creates block_configs for each layer + 3. Reorganizes weights into subblock checkpoints + + Args: + input_dir: Path to the input HuggingFace checkpoint directory. + output_dir: Path to the output AnyModel checkpoint directory. + converter: Either a converter name (e.g., "llama") or a Converter class. + + Example: + >>> convert_model( + ... input_dir="/path/to/Llama-3.1-8B-Instruct", + ... output_dir="/path/to/output/ckpts/teacher", + ... converter="llama", + ... ) + """ + input_dir = Path(input_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get descriptor and converter from factories (they use the same name) + descriptor = ModelDescriptorFactory.get(converter) + converter = ConverterFactory.get(converter) + + converter.convert(descriptor=descriptor, input_dir=input_dir, output_dir=output_dir) + + +if __name__ == "__main__": + from fire import Fire + + Fire(convert_model) diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter.py b/modelopt/torch/puzzletron/anymodel/converter/converter.py new file mode 100644 index 0000000000..5fdc92718c --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/converter.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import copy +import fnmatch +import json +import os +import shutil +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Dict, List + +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from transformers import PretrainedConfig +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config, save_model_config + +__all__ = ["Converter"] + + +class Converter(ABC): + """Base class for converting HuggingFace models to Puzzletron/AnyModel format.""" + + @staticmethod + def _get_weight_map(input_dir: Path) -> Dict[str, str]: + """Load weight map from checkpoint directory (supports both sharded and single-file models). + + Returns a dict mapping parameter names to their safetensors filenames. + """ + index_path = input_dir / "model.safetensors.index.json" + single_file_path = input_dir / "model.safetensors" + + if index_path.exists(): + # Sharded model + with open(index_path, "r") as f: + index = json.load(f) + return index["weight_map"] + elif single_file_path.exists(): + # Single file model - create a synthetic weight map + data = load_file(single_file_path) + return {name: "model.safetensors" for name in data.keys()} + else: + raise FileNotFoundError( + f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." + ) + + @classmethod + def convert_model_weights( + cls, input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int + ): + """Convert model weights to subblock format.""" + param_to_file = Converter._get_weight_map(input_dir) + all_param_names = list(param_to_file.keys()) + + # Reverse map: file -> set of params + file_to_params = defaultdict(set) + for name, file in param_to_file.items(): + file_to_params[file].add(name) + + # Determine subblocks needed + subblocks = descriptor.get_weight_groups( + all_param_names, num_hidden_layers=num_hidden_layers + ) + + # Output directory + out_dir = output_dir / "subblocks_safetensors" + os.makedirs(out_dir, exist_ok=True) + + # New weight index + new_index = {"metadata": {"format": "pt"}, "weight_map": {}} + + for subblock, param_names in tqdm(subblocks.items(), desc="Processing subblocks"): + param_files = set(param_to_file[name] for name in param_names) + tensors = {} + + # Load only needed files for this subblock + for file in param_files: + data = load_file(os.path.join(input_dir, file)) + for name in param_names: + if param_to_file[name] == file and name in data: + converted_name = cls.convert_weight_name(name) + # Convert MoE packed tensors if quantized is mxfp4 //gpt-oss-20b + if getattr(cls, "quantized", None) == "mxfp4": + if name.endswith("_blocks"): + converted_name = converted_name.replace("_blocks", "") + tensors[converted_name] = convert_moe_packed_tensors( + data[converted_name + "_blocks"], + data[converted_name + "_scales"], + ) + elif name.endswith("_scales"): + continue + else: + tensors[converted_name] = data[name] + else: + tensors[converted_name] = data[name] + + # Save this subblock + print(f"\n✅ Group: {subblock} ({len(tensors)} layers)") + for layer in tensors.keys(): + print(f" - {layer}") + + subblock_file = f"{subblock}.safetensors" + save_file(tensors, os.path.join(out_dir, subblock_file)) + + # Update index + for new_name in tensors.keys(): + new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}" + + # Save new index file + with (output_dir / "model.safetensors.index.json").open("w") as f: + json.dump(new_index, f, indent=2) + + print(f"✅ Finished saving subblocks and index to {output_dir}") + + @classmethod + def convert_configs_in_dirs( + cls, + input_dir: Path, + output_dir: Path, + ): + """Convert config and add block_configs.""" + config = load_model_config(input_dir) + + block_configs = cls.create_block_configs_from_main_config(config) + out_config = copy.deepcopy(config) + out_config.block_configs = block_configs + + save_model_config(out_config, output_dir) + return out_config + + @staticmethod + def copy_checkpoint_files(input_dir: Path, output_dir: Path): + """Copy checkpoint files except model weights (which will be converted).""" + ignore_patterns = [ + "model-*.safetensors", + "model.safetensors", + "model.safetensors.index.json", + "subblocks_safetensors", + ] + + def ignore_func(dir, files): + ignored = set() + for pattern in ignore_patterns: + ignored.update(fnmatch.filter(files, pattern)) + return ignored + + shutil.copytree(str(input_dir), str(output_dir), ignore=ignore_func, dirs_exist_ok=True) + + @classmethod + def convert( + cls, + descriptor: ModelDescriptor, + input_dir: Path, + output_dir: Path, + ): + """Convert a HuggingFace model to AnyModel format. + + Args: + descriptor: Model descriptor for the model type. + input_dir: Path to the input HuggingFace checkpoint. + output_dir: Path to the output AnyModel checkpoint. + """ + cls.copy_checkpoint_files(input_dir, output_dir) + config = cls.convert_configs_in_dirs(input_dir, output_dir) + cls.convert_model_weights( + input_dir, output_dir, descriptor=descriptor, num_hidden_layers=config.num_hidden_layers + ) + + @staticmethod + @abstractmethod + def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: + """Create per-layer BlockConfig list from a HuggingFace model config. + + This method extracts layer-specific parameters (e.g., intermediate_size, + num_key_value_heads) from the main model config and creates a BlockConfig + for each layer. These BlockConfigs enable layer-specific pruning and + modifications during the compression pipeline. + + Args: + config: HuggingFace PretrainedConfig (e.g., LlamaConfig, Qwen2Config) + + Returns: + List of BlockConfig, one per hidden layer. Each BlockConfig contains: + - AttentionConfig: attention settings (no_op, num_key_value_heads) + - FFNConfig: FFN settings (no_op, intermediate_size) + + Example: + For a model with uniform layers (e.g., Llama): + return [BlockConfig(...)] * config.num_hidden_layers + + For a model with heterogeneous layers (e.g., NemotronH with Mamba/Attention): + return [BlockConfig(...) for layer_idx in range(num_layers)] + """ + raise NotImplementedError + + @staticmethod + def convert_weight_name(name: str) -> str: + """ + Convert weight names during checkpoint conversion. + + This method can be overridden by subclasses to apply model-specific weight name + transformations when converting checkpoints from HuggingFace format to Puzzletron format. + + Default implementation returns the name unchanged (identity function). + + Args: + name: Original weight name from HuggingFace checkpoint + + Returns: + Converted weight name for Puzzletron format + + Example: + For Qwen2.5-VL, this converts: + - visual.* → model.visual.* + - model.* → model.language_model.* + """ + return name diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py new file mode 100644 index 0000000000..88d490d653 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import inspect +from typing import Callable, Type + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor + +__all__ = ["ConverterFactory"] + + +class ConverterFactory: + """Factory for registering and retrieving Converter classes.""" + + CLASS_MAPPING = {} + + @classmethod + def register(cls, **entries: Type): + """Register converter classes. + + Raises: + KeyError: if entry key is already in type_dict and points to a different class. + """ + for cls_name, cls_type in entries.items(): + if cls_name in cls.CLASS_MAPPING: + ref = cls.CLASS_MAPPING[cls_name] + # If ref and cls_name point to the same class ignore and don't raise an exception. + if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered converter by name or return the converter if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py new file mode 100644 index 0000000000..cc8e89e34b --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model descriptors for defining model-specific properties and layer naming conventions.""" + +from .model_descriptor import * +from .model_descriptor_factory import * diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py new file mode 100644 index 0000000000..73d56d2016 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Type + +import torch.nn as nn + +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock + +__all__ = ["ModelDescriptor"] + + +class ModelDescriptor(ABC): + @staticmethod + @abstractmethod + def decoder_layer_cls() -> Type[nn.Module] | List[Type[nn.Module]]: + """Decoder layer class types to patch for heterogeneous config support. + + In most cases this class will hold as attributes both FFN & attention layers. + + Returns: + nn.Module class type or a list if several class types should be patched. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]: + """Map between BlockConfig and layer config overrides. + + These overrides are consumed by a specific decoder layer and by the whole model. + Usage can be seen in `deci_x_patcher` under the method `_patched_decoder_layer_init`. + + Example implementation to override the FFN intermediate size of a block: + >>> def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]: + >>> return {"intermediate_size": block_config.ffn.intermediate_size} + """ + raise NotImplementedError + + @staticmethod + def mlp_no_op_post_init(decoder_layer: nn.Module): + """Post-init callback to alter a decoder layer so that FFN/mlp subblock performs as no-op. + + It is recommended to use the utils modules from `no_op.py` to replace layers to dummy + counterparts. + + Example for replacing a layernorm layer with identity: + + >>> decoder_layer.post_attention_layernorm = Same() + + Example for replacing an MLP layer with zeroes (zeroes since hidden_states are added to + the residuals hidden_states so a no-op implementation will leave residual the same): + + >>> decoder_layer.mlp = MatchingZeros() + + In case the MLP layer to replace returns multiple outputs i.e `hidden_states, _ = self.mlp()`, + use the util method `return_tuple_of_size` to return trailing None values: + + >>> decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)() + """ + raise NotImplementedError + + @staticmethod + def attn_no_op_post_init(decoder_layer: nn.Module): + """Post-init callback to alter a decoder layer so that Attention subblock performs as no-op. + + It is recommended to use the utils modules from `no_op.py` to replace layers to dummy + counterparts. + + Example for replacing a layernorm layer with identity: + + >>> decoder_layer.post_attention_layernorm = Same() + + Example for replacing an attention layer with zeroes: + + >>> decoder_layer.self_attn = MatchingZeros() + + In case the attention layer returns multiple outputs i.e `hidden_states, _ = self.self_attn()`, + use the util method `return_tuple_of_size` to return trailing None values: + + >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def init_rotary_embedding(model, runtime): + """Re-initiate the rotary embeddings based on an existing model. + + In puzzletron we initiate a sharded model by first creating a meta model then replacing + to the actual device by loading the state_dict with the real weights. + + Rotary embeddings frequencies are tensor buffers that are created dynamically during init + and are not part of the model state_dict, so cannot be restored after a meta device + initialization. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def input_embedding_name(): + """Return the name of the input embedding layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def output_embedding_name(): + """Return the name of the output embedding layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def final_norm_name(): + """Return the name of the final normalization layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def layer_block_name(index: int): + """Return the name of the decoder layer at the given index.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + """Return predicates for grouping model weights to support subblock checkpointing. + + For every group name return a regex predicate whether a layer name is part of the group. + + Returns: + Dictionary of group name to regex pattern predicate. + """ + raise NotImplementedError + + @staticmethod + def uses_autocast() -> bool: + """Whether this model supports torch.autocast. + + Some models (e.g., Qwen3-VL MoE) have dtype bugs under autocast. + Override and return False for models that do not support autocast. + """ + return True + + @staticmethod + def get_language_model_config(config): + """Get the language model config from a PretrainedConfig. + + For regular LM models, returns the config itself. + For VL/multimodal models with nested configs, override to return the + language model portion (e.g., config.text_config for Qwen-VL). + """ + return config + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + """Create a dummy block to replace a layer for sharded model initialization.""" + return DummyBlock(block_index=block_index) + + @classmethod + def mlp_no_op_supported(cls) -> bool: + """Check whether `mlp_no_op_post_init` is overridden for mlp no-op support.""" + method_name = ModelDescriptor.mlp_no_op_post_init.__name__ + return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name) + + @classmethod + def attn_no_op_supported(cls): + """Check whether `attn_no_op_post_init` is overridden for attention no-op support.""" + method_name = ModelDescriptor.attn_no_op_post_init.__name__ + return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name) + + @classmethod + def get_weight_groups( + cls, layer_names: Iterable[str], num_hidden_layers: int + ) -> Dict[str, List[str]]: + """Group model weights to support the puzzle subblock checkpointing format. + + This method uses the abstract method `layer_name_predicates` by default. + + Args: + layer_names: state_dict layer names of the model. + num_hidden_layers: number of decoder layers in the model. + + Returns: + Dictionary of group names to list of layer names per group, e.g.: + >>> { + ... "embedding": ["model.embed_tokens.weight"], + ... "lm_head": ["lm_head.weight", "model.norm.weight"], + ... "block_0_ffn": ["model.layers.0.mlp.down_proj", ...], + ... "block_0_attention": ["model.layers.0.self_attn.q_proj", ...], + ... } + """ + weight_groups = defaultdict(list) + for name in layer_names: + for group, pattern in cls.layer_name_predicates(num_hidden_layers).items(): + if pattern.match(name): + weight_groups[group].append(name) + break + else: + raise ValueError(f"Couldn't find a match for {name}") + return weight_groups diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py new file mode 100644 index 0000000000..badbe2b0e3 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import inspect +from typing import Callable, Type + +from transformers import AutoConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor import ModelDescriptor + +__all__ = ["ModelDescriptorFactory"] + +# Map from HuggingFace config.model_type (in checkpoint config.json) to ModelDescriptorFactory name. +# Local to this script; add entries when supporting new model types for auto-detection. +_MODEL_TYPE_TO_DESCRIPTOR = { + "llama": "llama", + "mistral": "mistral_small", + "qwen2": "qwen2", + "qwen3": "qwen3", + "nemotron_h": "nemotron_h", + "nemotron_h_v2": "nemotron_h_v2", + "gpt_oss_20b": "gpt_oss_20b", +} + + +def resolve_descriptor_from_pretrained(pretrained: str, trust_remote_code: bool = False): + """Resolve the model descriptor by loading the checkpoint config and mapping model_type. + + Args: + pretrained: Path to a pretrained model checkpoint or HuggingFace model identifier. + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + + Returns: + The resolved ModelDescriptor class for the detected model type. + + Raises: + ValueError: If pretrained is not provided or if the model type cannot be auto-detected. + """ + + config = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code) + model_type = getattr(config, "model_type", None) + + if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR: + detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type] + print( + f"[resolve_descriptor_from_pretrained] Auto-detected model_type='{model_type}' → descriptor='{detected}'" + ) + return ModelDescriptorFactory.get(detected) + + known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys()) + raise ValueError( + f"Cannot auto-detect descriptor for model_type='{model_type}'. " + f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported." + ) + + +class ModelDescriptorFactory: + """Factory for registering and retrieving ModelDescriptor classes.""" + + CLASS_MAPPING = {} + + @classmethod + def register(cls, **entries: Type): + """Register model descriptor classes. + + Raises: + KeyError: if entry key is already in type_dict and points to a different class. + """ + for cls_name, cls_type in entries.items(): + if cls_name in cls.CLASS_MAPPING: + ref = cls.CLASS_MAPPING[cls_name] + # If ref and cls_name point to the same class ignore and don't raise an exception. + if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered model descriptor by name or return the descriptor if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py new file mode 100644 index 0000000000..f2119059f4 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Import models to trigger factory registration +# from modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b import * +from modelopt.torch.puzzletron.anymodel.models.llama import * +# from modelopt.torch.puzzletron.anymodel.models.mistral_small import * +# from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * +# from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * +# from modelopt.torch.puzzletron.anymodel.models.qwen2 import * +# from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import * +# from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import * diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py new file mode 100644 index 0000000000..a0be9f919e --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.llama.llama_converter import LlamaConverter +from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import ( + LlamaModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py new file mode 100644 index 0000000000..5a0686ecc8 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Llama converter for AnyModel compression.""" + +from typing import List + +from transformers import LlamaConfig + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) + + +@ConverterFactory.register_decorator("llama") +class LlamaConverter(Converter): + """Converter for Llama models to AnyModel format.""" + + @staticmethod + def create_block_configs_from_main_config(config: LlamaConfig) -> List[BlockConfig]: + """Create uniform block configs for all Llama layers. + + Llama models have uniform architecture across all layers, so we create + the same BlockConfig for each layer. + """ + num_hidden_layers = config.num_hidden_layers + + block_configs = [ + BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + for _ in range(num_hidden_layers) + ] + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py new file mode 100644 index 0000000000..fe416e2dd6 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Llama model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaRotaryEmbedding, +) + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediateLayerDescriptor, +) + + +@ModelDescriptorFactory.register_decorator("llama") +class LlamaModelDescriptor(ModelDescriptor): + """Model descriptor for Llama models (Llama 2, Llama 3, Llama 3.1, Llama 3.2).""" + + @staticmethod + def decoder_layer_cls(): + return LlamaDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: LlamaDecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: LlamaDecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: LlamaForCausalLM, runtime): + model.model.rotary_emb = LlamaRotaryEmbedding(model.config, runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class LlamaFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + """Layer descriptor for Llama FFN intermediate pruning.""" + + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py new file mode 100644 index 0000000000..3af98d57fe --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for patching and transforming HuggingFace models to work with AnyModel. + +Provides no-op modules for layer replacement and patching utilities for heterogeneous +per-layer configurations. +""" + +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.utils import ( + deci_x_patcher, + override_config_with_block_configs, +) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py new file mode 100644 index 0000000000..aac57af0a9 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""No-op modules for replacing layers during pruning.""" + +from functools import cache + +import torch +import torch.nn as nn + + +@cache +def return_tuple_of_size(cls: type[nn.Module], size: int) -> type[nn.Module]: + """Create a wrapper class that returns a tuple of the given size. + + Useful for replacing modules that return multiple outputs (e.g., attention layers + that return (hidden_states, attn_weights)). + + Args: + cls: The base module class to wrap. + size: The size of the tuple to return. + + Returns: + A new class that wraps the base class and returns a tuple of the given size. + + Example: + >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + """ + + class Wrapped(cls): + def forward(self, *args, **kwargs): + result = super().forward(*args, **kwargs) + outputs = [None] * size + outputs[0] = result[0] + return tuple(outputs) + + def extra_repr(self) -> str: + return f"[{cls.__name__}]" + + return Wrapped + + +class MatchingZeros(nn.Module): + """Module that returns zeros matching the input shape. + + Used to replace MLP or attention layers with no-ops. Returns zeros because + the hidden_states are added to the residuals, so a no-op implementation + should leave the residual unchanged. + """ + + def forward(self, hidden_states, *args, **kwargs): + return torch.zeros_like(hidden_states) + + +class Same(nn.Module): + """Module that returns the input unchanged. + + Used to replace normalization layers with identity operations. + """ + + def forward(self, hidden_states, *args, **kwargs): + return hidden_states + + @property + def weight(self): + """Support NemotronH with scoring_activations, when lm_head is called `self.lm_head.weight.dtype`.""" + return torch.empty(0) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py b/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py new file mode 100644 index 0000000000..93913b8e2b --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import copy +import inspect +from contextlib import ExitStack, contextmanager +from functools import wraps +from typing import Any, Dict, List + +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + BlockConfig, + maybe_cast_block_configs, +) + + +def _get_variable_from_stack(names: list[str]) -> Any: + """Search the call stack for a variable with one of the given names.""" + f = inspect.currentframe().f_back + while f: + for name in names: + if name in f.f_locals: + return f.f_locals[name] + f = f.f_back + raise RuntimeError(f"{names} not found in caller stack") + + +@contextmanager +def deci_x_patcher( + model_descriptor: ModelDescriptor, + block_configs: List[BlockConfig | dict] | None = None, +): + """Context manager that patches decoder layer __init__ for heterogeneous per-layer configs. + + This is the core mechanism that enables AnyModel to work with any HuggingFace model. + It patches the decoder layer class(es) to read per-layer block_configs and apply + layer-specific overrides (e.g., different intermediate_size per layer). + + Args: + model_descriptor: The model descriptor that defines which classes to patch + and how to map block_configs to layer overrides. + block_configs: Optional list of BlockConfig (one per layer). If not provided, + will try to read from config.block_configs during model initialization. + + Example: + >>> with deci_x_patcher(LlamaModelDescriptor, block_configs): + ... model = AutoModelForCausalLM.from_config(config) + """ + decoder_layer_classes = model_descriptor.decoder_layer_cls() # Now a list of classes + if not isinstance(decoder_layer_classes, list): + decoder_layer_classes = [decoder_layer_classes] + + orig_inits = [] + for cls in decoder_layer_classes: + orig_inits.append(cls.__init__) + + block_configs = maybe_cast_block_configs(block_configs) + + @wraps(orig_inits[0]) + def _patched_decoder_layer_init(self, config, *args, **kwargs): + _block_configs = block_configs or getattr(config, "block_configs", None) + if _block_configs is None: + return orig_inits[decoder_layer_classes.index(self.__class__)]( + self, config, *args, **kwargs + ) + + _block_configs = maybe_cast_block_configs(_block_configs) + layer_idx = _get_variable_from_stack(["layer_idx", "idx"]) + _block_config = _block_configs[layer_idx] + override_block_config = model_descriptor.block_config_to_layer_overrides(_block_config) + _config = override_config_with_block_configs(config, override_block_config) + orig_inits[decoder_layer_classes.index(self.__class__)](self, _config, *args, **kwargs) + + # Apply no-op post-init + if _block_config.attention.no_op: + if not model_descriptor.attn_no_op_supported(): + raise NotImplementedError( + f"attn no-op not supported for `{model_descriptor.__class__.__name__}`, " + "please implement the method: `attn_no_op_post_init()`" + ) + model_descriptor.attn_no_op_post_init(decoder_layer=self) + + if _block_config.ffn.no_op: + if not model_descriptor.mlp_no_op_supported(): + raise NotImplementedError( + f"mlp no-op not supported for `{model_descriptor.__class__.__name__}`, " + "please implement the method: `mlp_no_op_post_init()`" + ) + model_descriptor.mlp_no_op_post_init(decoder_layer=self) + + with ExitStack() as stack: + # Patch every decoder layer class + for orig_init, cls in zip(orig_inits, decoder_layer_classes): + stack.callback(setattr, cls, "__init__", orig_init) # Restore on exit + cls.__init__ = _patched_decoder_layer_init + yield + + +def override_config_with_block_configs( + config: PretrainedConfig, block_configs: Dict[str, Any] +) -> PretrainedConfig: + """Create a copy of config with block_config overrides applied.""" + _config = copy.deepcopy(config) + # Model initialization requires fails with None in case of no-ops + _config_overrides = {k: v for k, v in block_configs.items() if v is not None} + _config.update(_config_overrides) + return _config diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py index d5eebfa352..a7212516a7 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py @@ -19,7 +19,7 @@ import warnings from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Type, Union, get_args, get_origin +from typing import Any, List, Optional, Type, Union, get_args, get_origin @dataclass(frozen=True, kw_only=True) @@ -178,106 +178,51 @@ class Llama4AttentionConfig(BaseDataclass): @dataclass(frozen=True, kw_only=True) class AttentionConfig(SubblockConfig): - n_heads_in_group: Optional[int] = None - window_length: Optional[int] = None - num_sink_tokens: Optional[int] = None - use_prefill_window_in_sink_attention: bool = False - unshifted_sink: bool = False - mamba: Optional[MambaConfig] = None + num_key_value_heads: Optional[int] = None llama4: Optional[Llama4AttentionConfig] = None + mamba: Optional[MambaConfig] = None def __post_init__(self): super().__post_init__() if self.no_op: - assert not self.replace_with_linear assert not self.is_mamba assert not self.is_llama4 - if self.no_op or self.replace_with_linear or self.is_mamba: + if self.no_op or self.is_mamba: for irrelevant_att in [ - "n_heads_in_group", - "window_length", - "num_sink_tokens", - "use_prefill_window_in_sink_attention", - "unshifted_sink", - "attention_chunk_size", - "attn_scale", - "floor_scale", - "attn_temperature_tuning", - "attention_dropout", - "use_qk_norm", + "num_key_value_heads", ]: self._force_setattr(irrelevant_att, None) else: - assert self.n_heads_in_group is not None - - if self.is_sink: - assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), ( - "Unshifted sink uses its own kind of explicit masking, not standard window. " - "Set use_prefill_window_in_sink_attention to False." - ) - assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), ( - "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True" - ) - - if self.is_llama4: - assert not self.is_sink, "Sink not support with Llama4 currently" - assert not self.is_sliding, "Sliding window not support with Llama4 currently" - assert not self.unshifted_sink, "Unshifted sink not support with Llama4 currently" + assert self.num_key_value_heads is not None def to_blockconfig(self) -> "BlockConfig": return BlockConfig(attention=self, ffn=FFNConfig(no_op=True)) @property - def prefill_sliding_window(self) -> Optional[int]: - if self.window_length is not None: - if not self.is_sink or self.use_prefill_window_in_sink_attention: - return self.window_length - return None - - @property - def is_sliding(self) -> bool: - return self.prefill_sliding_window is not None - - @property - def is_sink(self) -> bool: - return (self.window_length is not None) and (self.num_sink_tokens is not None) + def is_llama4(self) -> bool: + return self.llama4 is not None @property def is_mamba(self) -> bool: return self.mamba is not None - @property - def is_llama4(self) -> bool: - return self.llama4 is not None - @dataclass(frozen=True, kw_only=True) class FFNConfig(SubblockConfig): - gated: Optional[bool] = ( - True # Gated Linear Unit e.g. SwiGLU or vanilla MLP (up -> activation -> down) - ) - hidden_act: Optional[str] = "silu" moe: Optional[MoEConfig] = None intermediate_size: Optional[int] = None def __post_init__(self): super().__post_init__() - if self.no_op or self.replace_with_linear: - self._force_setattr("gated", None) - self._force_setattr("hidden_act", None) + if self.no_op: self._force_setattr("moe", None) self._force_setattr("intermediate_size", None) elif self.is_moe: - self._force_setattr("gated", None) - self._force_setattr("hidden_act", None) self._force_setattr("intermediate_size", None) else: - assert self.intermediate_size is not None, ( - "Intermediate size must be provided for an FFN block" - ) - assert self.intermediate_size % 256 == 0, "Intermediate size must be divisible by 256" + assert self.intermediate_size is not None, "Intermediate size must be provided for an FFN block" def to_blockconfig(self) -> "BlockConfig": return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self) @@ -306,3 +251,25 @@ def __post_init__(self): BlockConfig(**block_config) for block_config in self.parallel_blocks ] self._force_setattr("parallel_blocks", initialized_block_configs) + + def to_dict(self) -> dict: + """Convert BlockConfig to a dictionary.""" + return dataclasses.asdict(self) + + +def maybe_cast_block_configs( + block_configs: List[BlockConfig | dict] | None, +) -> List[BlockConfig] | None: + """Cast a list of dicts to BlockConfig objects if needed. + + Args: + block_configs: List of BlockConfig or dict objects, or None. + + Returns: + List of BlockConfig objects, or None if input is None/empty. + """ + if not block_configs: + return block_configs + if isinstance(block_configs[0], dict): + return [BlockConfig(**conf) for conf in block_configs] + return block_configs diff --git a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py index 5e1eace934..e5025dea7d 100644 --- a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py @@ -13,14 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Puzzletron NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). +""" +Puzzletron NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). -It is used by mtn.convert() to convert a model from HF format to DeciLM format + do pruning scoring +It is used by mtn.convert() to convert a model from HF format to Puzzletron heterogeneous format + do pruning scoring and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search. """ +import datetime from pathlib import Path +import hydra +import torch from torch import nn import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models @@ -39,9 +43,8 @@ from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict from modelopt.torch.puzzletron import build_library_and_stats from modelopt.torch.puzzletron.activation_scoring import score_pruning_activations -from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( - convert_llama3_to_decilm, -) +from modelopt.torch.puzzletron.anymodel.converter import ConverterFactory +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch.puzzletron.tools.logger import mprint @@ -90,7 +93,7 @@ class PuzzletronConfig(ModeloptBaseConfig): def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: - """1. Convert the model from HF format to DeciLM format. + """1. Convert the model from HF format to AnyModel format. 2. Score the pruning activations. 3. Prune the model and save pruned checkpoints @@ -111,14 +114,24 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv f"dataset_path={config.dataset_path}", ], ) + # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Convert Llama3 model to DeciLM model - # TODO: Make it generic, do not call convert_llama3_to_decilm directly. + # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config) if dist.is_master(): - mprint("Puzzletron Progress 2/8: converting model from HF to DeciLM (single-gpu)") + mprint( + "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)" + ) hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable - convert_llama3_to_decilm( - input_dir=config.input_model_path, + + # Get descriptor and converter from the hydra config + descriptor_name = hydra_cfg.descriptor + descriptor = ModelDescriptorFactory.get(descriptor_name) + converter = ConverterFactory.get(descriptor_name) + + converter.convert( + descriptor=descriptor, + input_dir=Path(config.input_model_path), output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, ) dist.barrier() @@ -162,6 +175,7 @@ def config_class(self) -> type[ModeloptBaseConfig]: @property def search_algorithm(self) -> type[BaseSearcher]: """Return the associated searcher implementation.""" + return PuzzletronSearcher @property @@ -201,6 +215,8 @@ def run_search(self) -> None: f"dataset_path={self.model.dataset_path}", ], ) + # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) # Build_library_and_stats (single process) if dist.is_master(): diff --git a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py new file mode 100644 index 0000000000..96d3489f5e --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( + ForwardHook, + GptOssRemoveExpertsIndependentHook, + NemotronHRemoveExpertsIndependentHook, + Qwen3VLRemoveExpertsIndependentHook, + RankedChoiceVotingHook, + RankedChoiceVotingHookNemotronH, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn +from modelopt.torch.puzzletron.pruning.pruning_utils import MlpInitMode, _init_moe_module + + +@dataclass +class ExpertRemovalLayerDescriptor(LayerDescriptor): + """ + TODO - Add Shared expert weights in case it's prunable. + TODO - consider removing the segmentation between weight and bias, doesn't seem to affect the pruning algo. + Attributes: + target_name: module name required to register hooks for scoring_activations, can be a regex if start with the prefix `regex:` + moe_prefix_name: moe prefix layer name, should include a placeholder for `layer_idx` to be repeated for all layers. i.e: `model.layers.{layer_idx}.moe` + expert_prefix_name: expert prefix layer name relative to moe_prefix, should include a placeholder for `expert_idx` to be repeated for all experts. i.e: `experts.{expert_idx}` + router_weights: List of the router weight names relative to moe_prefix. + router_biases: List of the router bias names relative to moe_prefix. + expert_weights: List of the expert weight names relative to expert_prefix (for per-expert format). + expert_biases: List of the expert bias names relative to expert_prefix (for per-expert format). + is_fused_experts: If True, experts are stored as single fused tensors with shape [num_experts, ...]. + If False (default), experts are stored as separate tensors per expert. + fused_expert_weights: List of fused expert weight names relative to moe_prefix (for fused format). + e.g., ["experts.gate_up_proj", "experts.down_proj"] + """ + + target_name: str + moe_prefix_name: str + expert_prefix_name: str = "" + router_weights: List[str] = field(default_factory=list) + router_biases: List[str] = field(default_factory=list) + expert_weights: List[str] = field(default_factory=list) + expert_biases: List[str] = field(default_factory=list) + is_fused_experts: bool = False + fused_expert_weights: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.target_name + + def moe_prefix(self, layer_idx: int) -> str: + return self.moe_prefix_name.format(layer_idx=layer_idx) + + def expert_prefix(self, layer_idx: int, expert_idx: int) -> str: + _expert_prefix = self.moe_prefix_name + "." + self.expert_prefix_name + return _expert_prefix.format(layer_idx=layer_idx, expert_idx=expert_idx) + + +class ExpertRemovalPruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: ExpertRemovalLayerDescriptor): + assert isinstance(layer_descriptor, ExpertRemovalLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [ + RankedChoiceVotingHook, + RankedChoiceVotingHookNemotronH, + NemotronHRemoveExpertsIndependentHook, + Qwen3VLRemoveExpertsIndependentHook, + GptOssRemoveExpertsIndependentHook, + ] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + keys: dict, + **kwargs, + ) -> Dict[str, torch.Tensor]: + layer_out_state_dict = {} + + child_block_config = new_config.block_configs[layer_idx] + parent_block_config = original_config.block_configs[layer_idx] + + if not parent_block_config.ffn.is_moe: + return layer_out_state_dict + + new_num_experts = child_block_config.ffn.moe.num_local_experts + orig_num_experts = parent_block_config.ffn.moe.num_local_experts + + child_router_keys, new_experts_keys = self._generate_moe_keys(layer_idx, new_num_experts) + parent_router_keys, orig_experts_keys = self._generate_moe_keys(layer_idx, orig_num_experts) + + # Pop parent's router keys from copy list; child-only router keys will be initialized below + for rk in sum(parent_router_keys.values(), []): + if rk in keys: + keys.pop(rk) + for key in sum(orig_experts_keys.values(), []): + if key in keys: + keys.pop(key) + + if self.layer_descriptor.is_fused_experts: + # Fused format: unbundle single tensor [num_experts, ...] into list of per-expert tensors + orig_experts_weights = {} + for name, fused_keys in orig_experts_keys.items(): + fused_tensor = parent_state_dict[fused_keys[0]] # Single fused tensor + orig_experts_weights[name] = [fused_tensor[i] for i in range(orig_num_experts)] + + new_experts_weights = {} + for name, fused_keys in new_experts_keys.items(): + fused_tensor = new_state_dict[fused_keys[0]] # Single fused tensor + new_experts_weights[name] = [fused_tensor[i] for i in range(new_num_experts)] + else: + # Per-expert format: load each expert tensor separately + orig_experts_weights = { + name: [parent_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + new_experts_weights = { + name: [new_state_dict[key] for key in new_experts_module_keys] + for name, new_experts_module_keys in new_experts_keys.items() + } + + orig_router_weights = { + name: [parent_state_dict[key] for key in _module_router_keys] + for name, _module_router_keys in parent_router_keys.items() + } + new_router_weights = { + name: [new_state_dict[key] for key in _module_router_keys] + for name, _module_router_keys in child_router_keys.items() + } + + out_router_weights, out_experts_weights = _init_moe_module( + layer_idx=layer_idx, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + orig_router_weights=orig_router_weights, + orig_experts_weights=orig_experts_weights, + new_router_weights=new_router_weights, + new_experts_weights=new_experts_weights, + orig_num_experts=orig_num_experts, + new_num_experts=new_num_experts, + ) + assert new_experts_keys.keys() == out_experts_weights.keys(), ( + "new_experts_keys and out_experts_weights must have the same keys" + ) + assert child_router_keys.keys() == out_router_weights.keys(), ( + "child_router_keys and out_router_weights must have the same keys" + ) + + for name in child_router_keys.keys(): + layer_out_state_dict.update(zip(child_router_keys[name], out_router_weights[name])) + + if self.layer_descriptor.is_fused_experts: + # Fused format: rebundle list of per-expert tensors into single fused tensor + for name in new_experts_keys.keys(): + fused_key = new_experts_keys[name][0] # Single key for fused tensor + fused_tensor = torch.stack(out_experts_weights[name], dim=0) # [num_experts, ...] + layer_out_state_dict[fused_key] = fused_tensor + else: + # Per-expert format: each expert has its own key + for name in new_experts_keys.keys(): + layer_out_state_dict.update(zip(new_experts_keys[name], out_experts_weights[name])) + + return layer_out_state_dict + + def _generate_moe_keys( + self, layer_idx: int, num_experts: int + ) -> Tuple[Dict[str, List[str]], dict[str, list[str]]]: + """ + Generate MoE weight keys for router and experts. + TODO simplify or better define the data structure of the moe keys returned. + + :return: tuple of router_keys and expert_keys, all are absolute names relative to the model root: + * router_keys structure: + {"weight: [], bias: []"} + * expert_keys structure (per-expert format): + {": []} + i.e: + { + "down_proj.weight": ["model...experts.0.down_proj.weight", ..., "model...experts.N.down_proj.weight"], + ... + } + * expert_keys structure (fused format): + {": []} + i.e: + { + "experts.gate_up_proj": ["model...experts.gate_up_proj"], + "experts.down_proj": ["model...experts.down_proj"], + } + """ + self.layer_descriptor: ExpertRemovalLayerDescriptor + moe_prefix = self.layer_descriptor.moe_prefix(layer_idx) + + router_keys = { + "weight": [ + f"{moe_prefix}.{_weight}" for _weight in self.layer_descriptor.router_weights + ], + "bias": [f"{moe_prefix}.{_bias}" for _bias in self.layer_descriptor.router_biases], + } + + if self.layer_descriptor.is_fused_experts: + # Fused format: single tensor per weight type with shape [num_experts, ...] + experts_module_names = {} + for fused_weight in self.layer_descriptor.fused_expert_weights: + experts_module_names[fused_weight] = [f"{moe_prefix}.{fused_weight}"] + else: + # Per-expert format: separate tensor for each expert + expert_key_names = ( + self.layer_descriptor.expert_weights + self.layer_descriptor.expert_biases + ) + experts_module_names = {} + for key_name in expert_key_names: + experts_module_names[key_name] = [ + f"{self.layer_descriptor.expert_prefix(layer_idx, expert_idx)}.{key_name}" + for expert_idx in range(num_experts) + ] + + return router_keys, experts_module_names diff --git a/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py new file mode 100644 index 0000000000..b3d9b88847 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Type + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( + ForwardHook, + IndependentChannelContributionHook, + IterativeChannelContributionHook, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + MlpInitMode, + _init_mlp_module, +) + + +@dataclass +class FFNIntermediateLayerDescriptor(LayerDescriptor): + down_proj_name: str + ffn_prefix_name: str + linear_weight_names: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.down_proj_name + + def ffn_prefix(self, layer_idx: int) -> str: + return self.ffn_prefix_name.format(layer_idx=layer_idx) + + +class FFNIntermediatePruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: FFNIntermediateLayerDescriptor): + assert isinstance(layer_descriptor, FFNIntermediateLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [IndependentChannelContributionHook, IterativeChannelContributionHook] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + keys: dict, + keys_to_remove: dict, + **kwargs, + ) -> Dict[str, torch.Tensor]: + layer_out_state_dict = {} + # Hardcoded strings + mlp_prefix = self.layer_descriptor.ffn_prefix(layer_idx) + mlp_key_names = [ + f"{mlp_prefix}.{name}.weight" for name in self.layer_descriptor.linear_weight_names + ] + mlp_keys = [keys.get(module_name) for module_name in mlp_key_names] + mlp_keys = [k for k in mlp_keys if k is not None] + + for key in mlp_keys: + keys_to_remove[f"{mlp_prefix}.{key.split('.')[-2]}.weight"] = key + + pruned_filters = None + projection_matrix = None + + for mlp_key in mlp_keys: + expanded_dim = 1 if self.layer_descriptor.down_proj_name in mlp_key else 0 + if mlp_key in new_state_dict.keys(): + mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( + mlp_init_mode, + mlp_prefix, + expanded_dim, + layer_idx, + new_state_dict[mlp_key], + new_config, + parent_state_dict[mlp_key], + original_config, + mlp_init_config, + pruned_filters, + projection_matrix, + ) + layer_out_state_dict[mlp_key] = mlp_module_weight + + return layer_out_state_dict diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py new file mode 100644 index 0000000000..f93e4b77ab --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +from dataclasses import dataclass, field +from typing import Any, List, Optional, Type + +from transformers import PretrainedConfig + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( + ForwardHook, + IndependentKvHeadContributionHook, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn +from modelopt.torch.puzzletron.pruning.pruning_utils import ( + GQAInitMode, + _init_attention_biases, + _init_attention_weights, +) + + +@dataclass +class KVHeadsLayerDescriptor(LayerDescriptor): + o_proj_name: str + attn_prefix_name: str + qkvo_weight_names: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.o_proj_name + + def attn_prefix(self, layer_idx: int) -> str: + return self.attn_prefix_name.format(layer_idx=layer_idx) + + +class KVHeadsPruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: KVHeadsLayerDescriptor): + assert isinstance(layer_descriptor, KVHeadsLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [IndependentKvHeadContributionHook] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + gqa_init_mode: GQAInitMode, + mlp_init_config: Optional[dict[str, Any]], + is_original_mha: bool, + keys: dict, + keys_to_remove: dict, + **kwargs, + ): + layer_out_state_dict = {} + + attn_prefix = self.layer_descriptor.attn_prefix(layer_idx) + q_name, k_name, v_name, o_name = [ + f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names + ] + + head_size = new_config.head_dim + for part in ["weight", "bias"]: + attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]] + q_key, k_key, v_key, o_key = attn_keys + + # Drop attn keys that don't exist and required to be in the new state_dict + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] + if len(attn_keys) > 0 and all(key in keys for key in attn_keys): + for key in attn_keys: + keys_to_remove[key] = keys[key] + is_student_and_teacher_have_same_attention_implementation = all( + key in new_state_dict.keys() for key in attn_keys + ) + if is_student_and_teacher_have_same_attention_implementation: + if part == "weight": + wq, wk, wv, wo = _init_attention_weights( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk + layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo + else: + bias_sd = _init_attention_biases( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]): + if bias_key in bias_sd.keys(): + layer_out_state_dict[sd_key] = bias_sd[bias_key] + + return layer_out_state_dict diff --git a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py index 5a0dfed01d..823f42faf8 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py +++ b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py @@ -23,14 +23,22 @@ import json import os import time +from typing import Optional from omegaconf import DictConfig -from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory +from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ExpertRemovalPruningMixIn +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediatePruningMixIn, +) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsPruningMixIn +from modelopt.torch.puzzletron.pruning.pruning_utils import ( GQAInitMode, HiddenSizeInitMode, LinearInitMode, MlpInitMode, + resolve_pruning_mixin, ) from modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent import ( init_child_from_parent, @@ -40,7 +48,7 @@ def launch_ffn_intermediates_prune_ckpt( - cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None ): for intermediate_size in cfg.pruning.intermediate_size_list: dirname = f"ffn_{intermediate_size}_attn_no_op" @@ -54,14 +62,16 @@ def launch_ffn_intermediates_prune_ckpt( model_config_overrides_json = {"ffn": [{"intermediate_size": intermediate_size}]} mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -83,7 +93,7 @@ def launch_ffn_intermediates_prune_ckpt( def launch_attn_groups_prune_ckpt( - cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None ): for n_heads_in_group in cfg.pruning.n_heads_in_group_list: dirname = f"n_heads_in_group{n_heads_in_group}" @@ -98,14 +108,16 @@ def launch_attn_groups_prune_ckpt( model_config_overrides_json = {"attention": [{"n_heads_in_group": n_heads_in_group}]} mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -150,17 +162,17 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): else: intermediate_sizes.append(None) - mprint("Teacher config:") + mprint(f"Teacher config:") mprint(f" - hidden_size: {parent_hidden_size}") mprint(f" - intermediate_sizes: {intermediate_sizes}") os.makedirs(os.path.join(cfg.puzzle_dir, "ckpts"), exist_ok=True) for hidden_size in cfg.pruning.hidden_size_list: - mprint("\n######################################################################") + mprint(f"\n######################################################################") mprint(f"Hidden Size = {hidden_size}") - mprint("######################################################################\n") + mprint(f"######################################################################\n") - mprint("Child config:") + mprint(f"Child config:") mprint(f" - hidden_size: {hidden_size}") # Create model config overrides with proper FFN configuration @@ -178,14 +190,16 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml dirname = f"hidden_size_{hidden_size}" - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) mprint(f"Creating checkpoint with hidden_size={hidden_size}") mprint(f"Model config overrides: {model_config_overrides_json}") init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.pruning.model_name_or_path, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -204,9 +218,9 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): def launch_experts_prune_ckpt( cfg: DictConfig, - max_save_workers: int | None = None, - max_layer_workers: int | None = None, - symlink_suffix: str | None = None, + max_save_workers: Optional[int] = None, + max_layer_workers: Optional[int] = None, + symlink_suffix: Optional[str] = None, ): for num_experts in cfg.pruning.num_experts_to_keep_list: dirname = f"num_experts_{num_experts}" @@ -223,14 +237,16 @@ def launch_experts_prune_ckpt( mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -252,7 +268,7 @@ def launch_experts_prune_ckpt( def launch_moe_ffn_intermediates_prune_ckpt( - cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None ): for intermediate_size in cfg.pruning.intermediate_size_list: dirname = f"moe_ffn_{intermediate_size}_attn_no_op" @@ -269,14 +285,16 @@ def launch_moe_ffn_intermediates_prune_ckpt( } mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -296,7 +314,11 @@ def launch_moe_ffn_intermediates_prune_ckpt( def launch_prune_ckpt(cfg: DictConfig): - target_layer = cfg.pruning.activation_hooks_kwargs.target_layer + cfg.descriptor = ModelDescriptorFactory.get(cfg.descriptor) + # Resolve pruning_mixin from config (could be string, enum, or PruningMixIn) + cfg.pruning.pruning_mixin = resolve_pruning_mixin(cfg.pruning.pruning_mixin, cfg.descriptor) + pruning_mixin = cfg.pruning.pruning_mixin + # I/O optimization settings - same as FFN pruning max_save_workers = None # Will auto-calculate as min(CPU count, num files) if "PRUNING_SAVE_WORKERS" in os.environ: @@ -307,29 +329,15 @@ def launch_prune_ckpt(cfg: DictConfig): if "PRUNING_LAYER_WORKERS" in os.environ: max_layer_workers = int(os.environ["PRUNING_LAYER_WORKERS"]) - # Log optimization settings (extracted from individual pruning methods) - mprint("Optimization Settings:") - mprint( - f" - I/O workers (max_workers): {'auto-calculate' if max_save_workers is None else max_save_workers}" - ) - mprint( - f" - Layer workers (max_layer_workers): {'auto-calculate' if max_layer_workers is None else max_layer_workers}" - ) - mprint(" (Override with env vars: PRUNING_IO_WORKERS, PRUNING_LAYER_WORKERS)") - - if target_layer == "mlp.down_proj": + if isinstance(pruning_mixin, FFNIntermediatePruningMixIn): launch_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) - elif target_layer == "self_attn.o_proj": + elif isinstance(pruning_mixin, KVHeadsPruningMixIn): launch_attn_groups_prune_ckpt(cfg, max_save_workers, max_layer_workers) - elif target_layer == "layernorm": - launch_hidden_dim_prune_ckpt(cfg) - elif target_layer == "router": - # Check if we should use symlink suffix for chained pruning - symlink_suffix = getattr(cfg.pruning, "symlink_suffix", None) - launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers, symlink_suffix) - elif target_layer == r"regex:experts\.\d+\.down_proj$": - launch_moe_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) + elif isinstance(pruning_mixin, ExpertRemovalPruningMixIn): + launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers) + # elif target_layer == "layernorm": + # launch_hidden_dim_prune_ckpt(cfg) else: raise NotImplementedError( - f"checkpoint pruning is not currently supported for target layer: {target_layer}" + f"checkpoint pruning is not currently supported for pruning mixin: {pruning_mixin.__class__.__name__}" ) diff --git a/modelopt/torch/puzzletron/pruning/pruning_mixin.py b/modelopt/torch/puzzletron/pruning/pruning_mixin.py new file mode 100644 index 0000000000..bcb422c4e6 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/pruning_mixin.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import re +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Type + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook + + +class LayerDescriptor: + def module_name_regex(self) -> str: + return "" + + def block_idx_from_module_name(self, module_name: str) -> Optional[int]: + block_idx_match = re.search(r"\.(\d+)\.", module_name) + if block_idx_match: + return int(block_idx_match.group(1)) + return None + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_layer = self.module_name_regex() + if target_layer.startswith("regex:"): + target_layer_regex = target_layer[len("regex:") :] + pattern = re.compile(target_layer_regex) + match_predicate = lambda module_name: pattern.search(module_name) + else: + match_predicate = lambda module_name: module_name.endswith(target_layer) + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if match_predicate(module_name): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook + + +class PruningMixIn(ABC): + def __init__(self, layer_descriptor: LayerDescriptor): + self.layer_descriptor = layer_descriptor + + def get_module_names_to_hook(self, model) -> List[Tuple[int, str]]: + return self.layer_descriptor.get_modules_names_to_hook(model) + + @abstractmethod + def supported_hooks(self) -> List[Type[ForwardHook]]: + raise NotImplementedError + + # @abstractmethod + # def prune_single_layer( + # self, + # layer_idx: int, + # parent_state_dict: dict, + # new_state_dict: dict, + # original_config: PretrainedConfig, + # new_config: PretrainedConfig, + # **kwargs + # ): + # raise NotImplementedError diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py new file mode 100644 index 0000000000..82ba675c94 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -0,0 +1,652 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import json +import math +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn + + +class GQAInitMode(Enum): + RandomKV = "RandomKV" + AverageKV = "AverageKV" + FirstKV = "FirstKV" + RandomBlock = "RandomBlock" + CopyAsIs = "CopyAsIs" + Degrouping = "Degrouping" + PruneKVHeads = "PruneKVHeads" + + +class MlpInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + CopyAsIs = "CopyAsIs" + PruneByActivationsLog = "PruneByActivationsLog" + ExpertRemoval = "ExpertRemoval" + ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + + +class LinearInitMode(Enum): + Random = "Random" + FromTeacher = "FromTeacher" + + +class HiddenSizeInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + PruneByChannelRanking = "PruneByChannelRanking" + CopyAsIs = "CopyAsIs" + + +def resolve_pruning_mixin( + pruning_mixin, descriptor: Type[ModelDescriptor] +) -> PruningMixIn | List[PruningMixIn]: + """ + Convert pruning_mixin argument to PruningMixIn instance(s). + + Args: + pruning_mixin: Can be a string identifier, PruningMixIn instance, + or a list of any of those types. + descriptor: ModelDescriptor class that provides the pruning_mixins() mapping. + + Returns: + PruningMixIn or List[PruningMixIn] depending on input type. + """ + # Handle list of values recursively + if isinstance(pruning_mixin, list): + return [resolve_pruning_mixin(item, descriptor) for item in pruning_mixin] + + # Handle single value + # If it's already a PruningMixIn, return as is + if isinstance(pruning_mixin, PruningMixIn): + return pruning_mixin + + # Get the pruning mixins mapping from the descriptor + mixins_dict = descriptor.pruning_mixins() + + if isinstance(pruning_mixin, str): + if pruning_mixin not in mixins_dict: + available_methods = list(mixins_dict.keys()) + raise ValueError( + f"Pruning method '{pruning_mixin}' is not supported by {descriptor.__name__}. " + f"Available methods: {available_methods}" + ) + return mixins_dict[pruning_mixin] + + raise ValueError(f"Unsupported pruning_mixin type: {type(pruning_mixin)}") + + +def _init_mlp_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_prefix: str, + expanded_dim: int, + layer_idx: int, + new_item: torch.Tensor, + new_config: PretrainedConfig, + orig_item: torch.Tensor, + original_config: PretrainedConfig, + mlp_init_config: Optional[dict[str, Any]], + pruned_filters: Optional[torch.Tensor] = None, + projection_matrix: Optional[dict[str, torch.Tensor]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + assert orig_item.ndim == 2, f"{orig_item.ndim=}" + assert new_item.ndim == 2, f"{new_item.ndim=}" + + assert new_config.num_hidden_layers == original_config.num_hidden_layers, ( + f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})" + ) + + new_intermediate_size = new_config.block_configs[layer_idx].ffn.intermediate_size + original_intermediate_size = original_config.block_configs[layer_idx].ffn.intermediate_size + + if mlp_init_mode == MlpInitMode.CopyAsIs: + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + elif mlp_init_mode == MlpInitMode.Random: + mlp_module_weight = new_item + + elif new_intermediate_size == original_intermediate_size: + mlp_module_weight = orig_item + + elif mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + ): + assert original_intermediate_size >= new_intermediate_size, ( + f"({original_intermediate_size=}) < ({new_intermediate_size=}), can't be truncated." + ) + orig_ffn_size = orig_item.shape[expanded_dim] + new_ffn_size = new_item.shape[expanded_dim] + + if mlp_init_mode == MlpInitMode.Truncate: + truncated_weight = torch.narrow( + orig_item, dim=expanded_dim, start=0, length=new_ffn_size + ) + mlp_module_weight = truncated_weight + + elif mlp_init_mode == MlpInitMode.PruneByActivationsLog: + if pruned_filters is None: + filter_importance = _load_activations_log( + mlp_init_config, module_name=f"{mlp_prefix}.down_proj" + ) + filters_sorted_by_importance = torch.argsort(filter_importance, descending=True) + pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device) + + pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters) + if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1: + pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size) + mlp_module_weight = pruned_weight + + elif ( + mlp_init_mode == MlpInitMode.ExpertRemoval + ): # the case of mlp layers of maverick. for now we only support copy as is + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + else: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + return mlp_module_weight, pruned_filters, projection_matrix + + +def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor: + _cache_activations_log(mlp_init_config) + module_log = ACTIVATIONS_LOG[module_name] + filter_importance = module_log["score"] + return filter_importance + + +ACTIVATIONS_LOG = dict() + + +def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: + if len(ACTIVATIONS_LOG) == 0: + assert "activations_log_dir" in mlp_init_config + activations_log_dir = mlp_init_config["activations_log_dir"] + print(f"Loading activations_log from {activations_log_dir}") + # Only load rank_*.pth files to avoid loading hook_states_*.pth checkpoint files + ACTIVATIONS_LOG.update( + { + module_name: module_log + for p in Path(activations_log_dir).glob("rank_*.pth") + for module_name, module_log in torch.load(p).items() + } + ) + + +def _init_attention_weights( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + + # new_w* are typically randomly initialized + new_wq = new_state_dict[q_key] + new_wk = new_state_dict[k_key] + new_wv = new_state_dict[v_key] + new_wo = new_state_dict[o_key] + + # w* are from the parent model + wq = original_state_dict[q_key] + wk = original_state_dict[k_key] + wv = original_state_dict[v_key] + wo = original_state_dict[o_key] + + if "bias" in k_key: + for tensor in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]: + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases + + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock): + wk, wv = new_wk, new_wv + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV): + assert orig_num_kv_heads % num_kv_heads == 0, ( + f"({orig_num_kv_heads=}) % ({num_kv_heads=}) != 0" + ) + n_heads_to_aggregate = orig_num_kv_heads // num_kv_heads + + wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1) + wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + wk = wk.mean(dim=1) + wv = wv.mean(dim=1) + else: + wk = wk[:, 0] + wv = wv[:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})" + assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})" + assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})" + assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})" + + elif gqa_init_mode == GQAInitMode.Degrouping: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = num_kv_heads + orig_n_groups = orig_num_kv_heads + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + wk = degroup_w(wk) + wv = degroup_w(wv) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + wk = wk.view(orig_num_kv_heads, head_size, dim1) + wv = wv.view(orig_num_kv_heads, head_size, dim1) + wq = wq.view(orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size, dim1) + wo = wo.view(dim1, orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size) + + o_proj_module_name = o_key.replace(".weight", "") + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + wk = wk[kv_heads_to_keep] + wv = wv[kv_heads_to_keep] + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + wq = wq[kv_heads_to_keep] + wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0) + + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + wo = wo[:, kv_heads_to_keep] + wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1) + wo = wo / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + wq = wq[kv_head_ordering] + + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + wo = wo[:, kv_head_ordering] + wo[:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + wk = wk.reshape(-1, dim1) + wv = wv.reshape(-1, dim1) + wq = wq.reshape(-1, dim1) + wo = wo.reshape(dim1, -1) + return wq, wk, wv, wo + + +def _init_attention_biases( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = num_q_heads // num_kv_heads + orig_n_heads_in_group = num_q_heads // orig_num_kv_heads + + o_proj_bias = new_config.o_proj_bias + attention_bias = new_config.attention_bias + + # If no biases + if not (o_proj_bias or attention_bias): + return {} + + new_bias_sd = {} + bias_sd = {} + # new_w* are typically randomly initialized + if o_proj_bias: + new_bias_sd["o"] = new_state_dict[o_key] + bias_sd["o"] = original_state_dict[o_key] + if attention_bias: + for bias_key, key in zip("qkv", [q_key, k_key, v_key]): + new_bias_sd[bias_key] = new_state_dict[key] + bias_sd[bias_key] = original_state_dict[key] + + # maybe unsqueeze all tensors + for tensor in list(new_bias_sd.values()) + list(bias_sd.values()): + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + + dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias: + bias_sd["k"] = torch.zeros( + new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device + ) + bias_sd["v"] = torch.zeros( + new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device + ) + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias: + assert n_heads_in_group % orig_n_heads_in_group == 0, ( + f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" + ) + n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group + + bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + bias_sd["k"] = bias_sd["k"].mean(dim=1) + bias_sd["v"] = bias_sd["v"].mean(dim=1) + else: + bias_sd["k"] = bias_sd["k"][:, 0] + bias_sd["v"] = bias_sd["v"][:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + for key in bias_sd.keys(): + assert new_bias_sd[key].shape == bias_sd[key].shape, ( + f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" + ) + + elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = new_config.num_attention_heads // n_heads_in_group + orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + bias_sd["k"] = degroup_w(bias_sd["k"]) + bias_sd["v"] = degroup_w(bias_sd["v"]) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + if o_proj_bias: + o_proj_module_name = o_key.rsplit(".", 1)[0] + else: + # Here we assume that the o_proj layer is called "o_proj" + o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj" + + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + # view as KV groups + if attention_bias: + bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["q"] = bias_sd["q"].view( + orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1 + ) + # Keep important KV heads and prune the others + bias_sd["k"] = bias_sd["k"][kv_heads_to_keep] + bias_sd["v"] = bias_sd["v"][kv_heads_to_keep] + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].view( + dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size + ) + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + if attention_bias: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + bias_sd["q"] = bias_sd["q"][kv_heads_to_keep] + bias_sd["q"] = torch.repeat_interleave( + bias_sd["q"], repeats=reduction_factor, dim=0 + ) + + if o_proj_bias: + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep] + bias_sd["o"] = torch.repeat_interleave( + bias_sd["o"], repeats=reduction_factor, dim=1 + ) + bias_sd["o"] = bias_sd["o"] / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + if attention_bias: + bias_sd["q"] = bias_sd["q"][kv_head_ordering] + + if o_proj_bias: + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + bias_sd["o"] = bias_sd["o"][:, kv_head_ordering] + bias_sd["o"][:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + if attention_bias: + for bias_key in "qkv": + bias_sd[bias_key] = bias_sd[bias_key].reshape(-1) + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].reshape(-1) + return bias_sd + + +def _init_moe_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_init_config: Optional[Dict[str, Any]], + layer_idx: int, + orig_router_weights: Dict[str, List[torch.Tensor]], + orig_experts_weights: Dict[str, List[torch.Tensor]], + new_router_weights: Dict[str, List[torch.Tensor]], + new_experts_weights: Dict[str, List[torch.Tensor]], + orig_num_experts: int, + new_num_experts: int, +) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + + if mlp_init_mode != MlpInitMode.ExpertRemoval: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + selected_experts = _select_expert_indices( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_num_experts=orig_num_experts, + new_num_experts=new_num_experts, + ) + + # Router: prefer parent tensors when available; if child has bias only, slice from child + result_router_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_router_weights.items(): + result_router_weights[name] = [ + tensor_to_slice[selected_experts] for tensor_to_slice in orig_router_weights[name] + ] + + # Experts: for each name present in the child, take from parent if available, else from child + result_experts_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_experts_weights.items(): + if name in orig_experts_weights: + src_list = orig_experts_weights[name] + else: + src_list = new_list + result_experts_weights[name] = [src_list[i] for i in selected_experts] + + # Validate shapes + assert result_router_weights.keys() == new_router_weights.keys(), ( + "result_router_weights and new_router_weights must have the same keys" + ) + for name in new_router_weights.keys(): + assert len(new_router_weights[name]) == len(result_router_weights[name]) + for new_router_weight, result_router_weight in zip( + new_router_weights[name], result_router_weights[name] + ): + assert new_router_weight.shape == result_router_weight.shape + + assert result_experts_weights.keys() == new_experts_weights.keys(), ( + "result_experts_weights and new_experts_weights must have the same keys" + ) + for name in result_experts_weights.keys(): + assert len(new_experts_weights[name]) == len(result_experts_weights[name]) + for new_expert_weight, result_expert_weight in zip( + new_experts_weights[name], result_experts_weights[name] + ): + assert new_expert_weight.shape == result_expert_weight.shape + + return result_router_weights, result_experts_weights + + +def _select_expert_indices( + *, mlp_init_config: dict[str, Any], layer_idx: int, orig_num_experts: int, new_num_experts: int +) -> list[int]: + expert_scores = _load_expert_scores(mlp_init_config, layer_idx) + assert len(expert_scores) == orig_num_experts + higher_is_better = mlp_init_config.get("higher_is_better", True) + selected_experts = sorted( + range(orig_num_experts), + key=lambda i: ( + expert_scores[i] + if not math.isnan(expert_scores[i]) + else (float("-inf") if higher_is_better else float("inf")) + ), + reverse=higher_is_better, + )[:new_num_experts] + return selected_experts + + +def _load_expert_scores( + mlp_init_config: Optional[dict[str, Any]], layer_idx: int +) -> list[list[int | float]]: + assert mlp_init_config is not None + if "expert_scores_file" in mlp_init_config: + expert_scores_file = mlp_init_config["expert_scores_file"] + with open(expert_scores_file, "r") as f: + expert_scores = json.load(f) + elif "activations_log_dir" in mlp_init_config: + _cache_activations_log(mlp_init_config) + # Use layer_prefix_template from pruning config, or fall back to legacy nemotron_h format + # TODO - get from descriptors + layer_prefix_template = mlp_init_config.get( + "layer_prefix_template", "backbone.layers.{layer_idx}." + ) + layer_prefix = layer_prefix_template.format(layer_idx=layer_idx) + candidate_layer_keys = [ + key for key in ACTIVATIONS_LOG.keys() if key.startswith(layer_prefix) + ] + if len(candidate_layer_keys) == 0: + raise ValueError(f"No layer keys found for {layer_prefix=}. {ACTIVATIONS_LOG.keys()=}") + elif len(candidate_layer_keys) > 1: + if "layer_suffix" not in mlp_init_config: + raise ValueError( + f"Multiple candidate layer keys found for {layer_prefix=}, you must specify a layer_suffix in the mlp_init_config. {candidate_layer_keys=}" + ) + layer_suffix = mlp_init_config["layer_suffix"] + layer_key = f"{layer_prefix}{layer_suffix}" + else: + layer_key = candidate_layer_keys[0] + layer_log = ACTIVATIONS_LOG[layer_key] + + expert_scores_key = mlp_init_config.get("expert_scores_key", "expert_ranks") + if expert_scores_key not in layer_log: + raise ValueError( + f"Expert scores key {expert_scores_key=} not found in {layer_log.keys()=}" + ) + expert_scores = layer_log[expert_scores_key] + else: + raise ValueError(f"Unsupported {mlp_init_config=}") + return expert_scores diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index 3981b62e34..b30e7eefa9 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -14,7 +14,7 @@ # limitations under the License. # mypy: ignore-errors -"""TODO Add description. Analyze this code, why is it so long and complex? Can it be simplified?""" +"""Core logic for creating pruned child model state dicts from parent models. Used by init_child_from_parent.""" import concurrent.futures import dataclasses @@ -22,12 +22,11 @@ import os import re import time -from collections.abc import Callable from copy import deepcopy from enum import Enum from functools import partial from pathlib import Path -from typing import Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from typeguard import check_type @@ -39,41 +38,23 @@ _is_dataclass_type, ) from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.pruning.pruning_utils import ( + ACTIVATIONS_LOG, + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + _cache_activations_log, + _init_attention_biases, + _init_attention_weights, + _init_mlp_module, + _init_moe_module, + _load_activations_log, + _load_expert_scores, + _select_expert_indices, +) from modelopt.torch.puzzletron.tools.logger import aprint, mprint - -class GQAInitMode(Enum): - RandomKV = "RandomKV" - AverageKV = "AverageKV" - FirstKV = "FirstKV" - RandomBlock = "RandomBlock" - CopyAsIs = "CopyAsIs" - Degrouping = "Degrouping" - PruneKVHeads = "PruneKVHeads" - - -class MlpInitMode(Enum): - Random = "Random" - Truncate = "Truncate" - CopyAsIs = "CopyAsIs" - PruneByActivationsLog = "PruneByActivationsLog" - ExpertRemoval = "ExpertRemoval" - ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" - MoEChannelPruning = "MoEChannelPruning" - - -class LinearInitMode(Enum): - Random = "Random" - FromTeacher = "FromTeacher" - - -class HiddenSizeInitMode(Enum): - Random = "Random" - Truncate = "Truncate" - PruneByChannelRanking = "PruneByChannelRanking" - CopyAsIs = "CopyAsIs" - - IgnoreFn = Callable[[str], bool] default_ignore_fn: IgnoreFn = lambda _: False @@ -87,25 +68,52 @@ def print(s: str) -> None: def _process_single_layer( layer_idx: int, + pruning_mixin, + descriptor, parent_state_dict: dict, new_state_dict: dict, original_config: DeciLMConfig, new_config: DeciLMConfig, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, - mlp_init_config: dict[str, Any] | None, + mlp_init_config: Optional[dict[str, Any]], linear_init_mode: LinearInitMode, ignored_keys: set, keys: dict, is_original_mha: bool, head_size: int, hidden_size: int, -) -> tuple[dict[str, torch.Tensor], dict[str, str]]: - """Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). +) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]: + """ + Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). Thread-safe function for parallel layer processing. """ - layer_out_state_dict = {} keys_to_remove = {} + layer_out_state_dict = {} + + # Delegate to pruning_mixin if available + if pruning_mixin is not None: + _layer_out = pruning_mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) + return layer_out_state_dict, keys_to_remove + + # Legacy inline processing (fallback when no pruning_mixin) parent_block_config = original_config.block_configs[layer_idx] child_block_config = new_config.block_configs[layer_idx] @@ -119,13 +127,13 @@ def _process_single_layer( o_key = f"{attn_prefix}.o_proj.{part}" attn_keys = [q_key, k_key, v_key, o_key] # Drop attn keys that don't exist and required to be in the new state_dict - attn_keys = [key for key in attn_keys if key in new_state_dict] + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] if len(attn_keys) > 0 and all(key in keys for key in attn_keys): for key in attn_keys: keys_to_remove[key] = keys[key] if all(key not in ignored_keys for key in attn_keys): is_student_and_teacher_have_same_attention_implementation = all( - key in new_state_dict for key in attn_keys + key in new_state_dict.keys() for key in attn_keys ) if is_student_and_teacher_have_same_attention_implementation: if part == "weight": @@ -168,7 +176,7 @@ def _process_single_layer( else: linear_attn_key = f"{attn_prefix}.linear_attn.weight" - is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict + is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict.keys() if is_student_attn_replaced_with_linear: if linear_init_mode == LinearInitMode.Random: layer_out_state_dict[linear_attn_key] = new_state_dict[linear_attn_key] @@ -180,7 +188,7 @@ def _process_single_layer( raise ValueError(f"Unknown {linear_init_mode=}") else: # student attn random init - for new_key in new_state_dict: + for new_key in new_state_dict.keys(): if attn_prefix in new_key: layer_out_state_dict[new_key] = new_state_dict[new_key] @@ -190,7 +198,7 @@ def _process_single_layer( mlp_prefix = f"model.layers.{layer_idx}.mlp" linear_mlp_key = f"{mlp_prefix}.linear_mlp.weight" - is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict + is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict.keys() if is_student_mlp_replaced_with_linear: if linear_init_mode == LinearInitMode.Random: layer_out_state_dict[linear_mlp_key] = new_state_dict[linear_mlp_key] @@ -312,7 +320,7 @@ def _process_single_layer( ]: key_possibly_missing_in_student = f".{layer_idx}.{key_possibly_missing_in_student}" is_key_missing_from_student = ( - len([k for k in new_state_dict if key_possibly_missing_in_student in k]) == 0 + len([k for k in new_state_dict.keys() if key_possibly_missing_in_student in k]) == 0 ) if is_key_missing_from_student: for k in list(keys.keys()): @@ -324,6 +332,8 @@ def _process_single_layer( @torch.no_grad() def create_child_state_dict( + pruning_mixin, + descriptor, original_state_dict: dict, new_state_dict: dict, original_config: DeciLMConfig, @@ -331,12 +341,12 @@ def create_child_state_dict( gqa_init_mode: GQAInitMode, ignore_fn: IgnoreFn = default_ignore_fn, mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, - mlp_init_config: dict[str, Any] | None = None, - owned_block_indexes: set[int] | None = None, + mlp_init_config: Optional[dict[str, Any]] = None, + owned_block_indexes: Optional[set[int]] = None, linear_init_mode: LinearInitMode = LinearInitMode.Random, hidden_size_init_mode: HiddenSizeInitMode = HiddenSizeInitMode.CopyAsIs, - channel_importance_path: str | None = None, - max_layer_workers: int | None = None, # Now optional - will auto-calculate if None + channel_importance_path: Optional[str] = None, + max_layer_workers: Optional[int] = None, # Now optional - will auto-calculate if None ): mprint("=== Starting create_child_state_dict with optimizations ===") total_start_time = time.time() @@ -371,34 +381,40 @@ def create_child_state_dict( else: out_state_dict[key] = tensor - original_n_heads_in_group_per_layer = [ - b.attention.n_heads_in_group for b in original_config.block_configs + # Get language model config for LM-specific attributes (VL models have nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + # Check if original model is MHA (all layers have num_key_value_heads == num_attention_heads) + original_num_kv_heads_per_layer = [ + b.attention.num_key_value_heads for b in original_config.block_configs ] - is_original_mha = set(original_n_heads_in_group_per_layer) == {1} - is_same_hidden_size = original_config.hidden_size == new_config.hidden_size - head_size = new_config.head_dim - orig_head_size = original_config.head_dim + num_attention_heads = original_lm_config.num_attention_heads + is_original_mha = all(kv == num_attention_heads for kv in original_num_kv_heads_per_layer) + is_same_hidden_size = original_lm_config.hidden_size == new_lm_config.hidden_size + head_size = _get_head_dim(new_lm_config) + orig_head_size = _get_head_dim(original_lm_config) assert head_size == orig_head_size, f"head_size {head_size} != orig_head_size {orig_head_size}" # Allow different hidden sizes for pruning if not is_same_hidden_size: - assert new_config.hidden_size <= original_config.hidden_size, ( - f"New hidden size ({new_config.hidden_size}) must be <= original ({original_config.hidden_size})" + assert new_lm_config.hidden_size <= original_lm_config.hidden_size, ( + f"New hidden size ({new_lm_config.hidden_size}) must be <= original ({original_lm_config.hidden_size})" ) assert hidden_size_init_mode != HiddenSizeInitMode.CopyAsIs, ( "Cannot copy as is when hidden sizes differ" ) - hidden_size = original_config.hidden_size + hidden_size = original_lm_config.hidden_size - ignored_keys = set([key for key in original_state_dict if ignore_fn(key)]) + ignored_keys = set([key for key in original_state_dict.keys() if ignore_fn(key)]) for key in ignored_keys: aprint(f"Ignoring key {key} and taking its init from new_state_dict") out_state_dict[key] = new_state_dict[key] keys = { match.group(1) if (match := re.search(r"(h\.\d+\..*)", key)) is not None else key: key - for key in original_state_dict + for key in original_state_dict.keys() } setup_time = time.time() - setup_start_time mprint(f"Phase 1 - Setup and memory pre-allocation: {setup_time:.2f}s") @@ -409,6 +425,8 @@ def create_child_state_dict( # Prepare arguments for parallel processing process_layer_partial = partial( _process_single_layer, + pruning_mixin=pruning_mixin, + descriptor=descriptor, parent_state_dict=original_state_dict, new_state_dict=new_state_dict, original_config=original_config, @@ -489,6 +507,7 @@ def create_child_state_dict( original_state_dict, new_config, original_config, + descriptor, hidden_size_init_mode, channel_importance_path, owned_block_indexes, @@ -527,7 +546,7 @@ def _generate_moe_keys(layer_idx: int, num_experts: int) -> tuple[str, dict[str, def _concatenate_experts_into_dense_ffn( original_state_dict: dict[str, torch.Tensor], - mlp_init_config: dict | None, + mlp_init_config: Optional[dict], hidden_size: int, layer_idx: int, child_block_config: BlockConfig, @@ -585,7 +604,8 @@ def _concatenate_experts_into_dense_ffn( "concat_dims and experts_weights must have the same keys" ) concat_routed_state_dict = { - name: torch.cat(experts_weights[name], dim=concat_dims[name]) for name in concat_dims + name: torch.cat(experts_weights[name], dim=concat_dims[name]) + for name in concat_dims.keys() } # turn the shared expert into a normal FFN. concatenate the pruned routed experts if needed. @@ -645,16 +665,16 @@ def _verify_state_dicts_match( def _init_mlp( *, - mlp_init_mode: MlpInitMode | str, + mlp_init_mode: Union[MlpInitMode, str], layer_idx: int, original_config: DeciLMConfig, - mlp_init_config: dict[str, Any] | None, + mlp_init_config: Optional[dict[str, Any]], original_state_dict: dict, new_state_dict: dict, new_config: DeciLMConfig, keys: dict[str, str], ignored_keys: set[str], - expert_idx: int | None = None, + expert_idx: Optional[int] = None, ) -> dict[str, torch.Tensor]: out_state_dict = {} @@ -679,10 +699,12 @@ def _init_mlp( projection_matrix = None for mlp_key in mlp_keys: expanded_dim = 1 if "down_proj" in mlp_key else 0 - if mlp_key in new_state_dict: + if mlp_key in new_state_dict.keys(): mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( mlp_init_mode, + mlp_prefix, expanded_dim, + layer_idx, new_state_dict[mlp_key], new_config, original_state_dict[mlp_key], @@ -690,7 +712,6 @@ def _init_mlp( mlp_init_config, pruned_filters, projection_matrix, - mlp_prefix, ) out_state_dict[mlp_key] = mlp_module_weight else: @@ -698,128 +719,6 @@ def _init_mlp( return out_state_dict -def _init_mlp_module( - mlp_init_mode: MlpInitMode | str, - expanded_dim: int, - new_item: torch.Tensor, - new_config: DeciLMConfig, - orig_item: torch.Tensor, - original_config: DeciLMConfig, - mlp_init_config: dict[str, Any] | None, - pruned_filters: torch.Tensor | None = None, - projection_matrix: dict[str, torch.Tensor] | None = None, - mlp_prefix: str | None = None, -) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor] | None]: - if isinstance(mlp_init_mode, str): - mlp_init_mode = MlpInitMode(mlp_init_mode) - assert orig_item.ndim == 2, f"{orig_item.ndim=}" - assert new_item.ndim == 2, f"{new_item.ndim=}" - - assert new_config.num_hidden_layers == original_config.num_hidden_layers, ( - f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})" - ) - - orig_ffn_size = orig_item.shape[expanded_dim] - new_ffn_size = new_item.shape[expanded_dim] - - if mlp_init_mode == MlpInitMode.CopyAsIs: - assert new_ffn_size == orig_ffn_size, ( - f"({new_ffn_size=}) != ({orig_ffn_size=}), can't be copied as is." - ) - mlp_module_weight = orig_item - - elif mlp_init_mode == MlpInitMode.Random: - mlp_module_weight = new_item - - elif new_ffn_size == orig_ffn_size: - mlp_module_weight = orig_item - - elif mlp_init_mode in ( - MlpInitMode.Truncate, - MlpInitMode.PruneByActivationsLog, - MlpInitMode.MoEChannelPruning, - ): - assert new_ffn_size <= orig_ffn_size, ( - f"({new_ffn_size=}) > ({orig_ffn_size=}), can't be truncated." - ) - - if mlp_init_mode == MlpInitMode.Truncate: - truncated_weight = torch.narrow( - orig_item, dim=expanded_dim, start=0, length=new_ffn_size - ) - mlp_module_weight = truncated_weight - - elif mlp_init_mode in (MlpInitMode.PruneByActivationsLog, MlpInitMode.MoEChannelPruning): - if pruned_filters is None: - filter_importance = _load_activations_log( - mlp_init_config, module_name=f"{mlp_prefix}.down_proj" - ) - filters_sorted_by_importance = torch.argsort(filter_importance, descending=True) - pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device) - - pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters) - if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1: - pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size) - mlp_module_weight = pruned_weight - - elif ( - mlp_init_mode == MlpInitMode.ExpertRemoval - ): # the case of mlp layers of maverick. for now we only support copy as is - assert new_ffn_size == orig_ffn_size, ( - f"({new_ffn_size=}) != ({orig_ffn_size=}), can't be copied as is." - ) - mlp_module_weight = orig_item - - else: - raise ValueError(f"Unsupported {mlp_init_mode=}") - - return mlp_module_weight, pruned_filters, projection_matrix - - -def _init_moe_module( - *, - mlp_init_mode: MlpInitMode | str, - mlp_init_config: dict[str, Any] | None, - layer_idx: int, - orig_router_weight: torch.Tensor, - orig_experts_weights: dict[str, list[torch.Tensor]], - new_router_weight: torch.Tensor, - new_experts_weights: dict[str, list[torch.Tensor]], -) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor] | None]: - if isinstance(mlp_init_mode, str): - mlp_init_mode = MlpInitMode(mlp_init_mode) - - if mlp_init_mode == MlpInitMode.ExpertRemoval: - result_router_weight, result_experts_weights = _prune_experts_by_score( - mlp_init_config=mlp_init_config, - layer_idx=layer_idx, - orig_router_weight=orig_router_weight, - orig_experts_weights=orig_experts_weights, - new_num_experts=new_router_weight.shape[0], - ) - else: - raise ValueError(f"Unsupported {mlp_init_mode=}") - - assert result_router_weight.shape == new_router_weight.shape - assert result_experts_weights.keys() == new_experts_weights.keys(), ( - "result_experts_weights and new_experts_weights must have the same keys" - ) - assert all( - len(new_experts_weights[name]) == len(result_experts_weights[name]) - for name in result_experts_weights.keys() - ) - assert all( - all( - new_expert_weight.shape == result_expert_weight.shape - for new_expert_weight, result_expert_weight in zip( - new_experts_weights[name], result_experts_weights[name] - ) - ) - for name in result_experts_weights.keys() - ) - return result_router_weight, result_experts_weights - - def _prune_experts_by_score( *, mlp_init_config: dict[str, Any], @@ -848,377 +747,6 @@ def _prune_experts_by_score( return result_router_weight, result_experts_weights -def _load_expert_scores(mlp_init_config: dict[str, Any] | None) -> list[list[int | float]]: - assert mlp_init_config is not None - if "expert_scores_file" in mlp_init_config: - expert_scores_file = mlp_init_config["expert_scores_file"] - with open(expert_scores_file) as f: - expert_scores = json.load(f) - elif "activations_log_dir" in mlp_init_config: - _cache_activations_log(mlp_init_config) - num_layers = len(ACTIVATIONS_LOG) - expert_scores = [] - for layer_idx in range(num_layers): - router_name = f"model.layers.{layer_idx}.mlp.router" - expert_scores.append(ACTIVATIONS_LOG[router_name]["expert_ranks"]) - expert_scores = torch.stack(expert_scores) - expert_scores = expert_scores.tolist() - else: - raise ValueError(f"Unsupported {mlp_init_config=}") - return expert_scores - - -ACTIVATIONS_LOG = dict() - - -def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: - if len(ACTIVATIONS_LOG) == 0: - assert "activations_log_dir" in mlp_init_config - activations_log_dir = mlp_init_config["activations_log_dir"] - ACTIVATIONS_LOG.update( - { - module_name: module_log - for p in Path(activations_log_dir).glob("rank*.pth") - for module_name, module_log in torch.load(p).items() - } - ) - - -def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor: - _cache_activations_log(mlp_init_config) - module_log = ACTIVATIONS_LOG[module_name] - filter_importance = module_log["score"] - return filter_importance - - -def _init_attention_weights( - gqa_init_mode, - layer_idx, - new_state_dict, - new_config, - original_state_dict, - q_key, - k_key, - v_key, - o_key, - original_config, - is_original_mha, - head_size, - mlp_init_config, -): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" - ) - num_q_heads = new_config.num_attention_heads - n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group - orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group - num_kv_heads = num_q_heads // n_heads_in_group - orig_num_kv_heads = num_q_heads // orig_n_heads_in_group - - # new_w* are typically randomly initialized - new_wq = new_state_dict[q_key] - new_wk = new_state_dict[k_key] - new_wv = new_state_dict[v_key] - new_wo = new_state_dict[o_key] - - # w* are from the parent model - wq = original_state_dict[q_key] - wk = original_state_dict[k_key] - wv = original_state_dict[v_key] - wo = original_state_dict[o_key] - - if "bias" in k_key: - for tensor in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]: - assert tensor.ndim == 1 - tensor.unsqueeze_(1) - dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases - - if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock): - wk, wv = new_wk, new_wv - elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV): - assert n_heads_in_group % orig_n_heads_in_group == 0, ( - f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" - ) - n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group - - wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1) - wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1) - - if gqa_init_mode == GQAInitMode.AverageKV: - wk = wk.mean(dim=1) - wv = wv.mean(dim=1) - else: - wk = wk[:, 0] - wv = wv[:, 0] - elif gqa_init_mode == GQAInitMode.CopyAsIs: - assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})" - assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})" - assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})" - assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})" - - elif gqa_init_mode == GQAInitMode.Degrouping: - assert not is_original_mha, ( - "Degrouping can only be done on original models that are GQA themselves." - ) - n_groups = new_config.num_attention_heads // n_heads_in_group - orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group - assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" - n_repeats = n_groups // orig_n_groups - if n_repeats > 1: - print(f"Degrouping {orig_n_groups} into {n_groups}") - - def degroup_w(w): - w = w.view(orig_n_groups, head_size, dim1) - w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) - w = w.reshape(n_groups * head_size, dim1) - return w - - wk = degroup_w(wk) - wv = degroup_w(wv) - - elif gqa_init_mode == GQAInitMode.PruneKVHeads: - wk = wk.view(orig_num_kv_heads, head_size, dim1) - wv = wv.view(orig_num_kv_heads, head_size, dim1) - wq = wq.view(orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1) - wo = wo.view(dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size) - - o_proj_module_name = o_key.replace(".weight", "") - kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) - kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) - kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] - kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] - - wk = wk[kv_heads_to_keep] - wv = wv[kv_heads_to_keep] - - reduction_factor = orig_num_kv_heads // num_kv_heads - - prune_via_duplication = False - if prune_via_duplication: - ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. - wq = wq[kv_heads_to_keep] - wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0) - - ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. - ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. - wo = wo[:, kv_heads_to_keep] - wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1) - wo = wo / reduction_factor - - else: # prune via zeroing out - ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. - ## We need to interleave them to keep the matching between queries and kv heads. - kv_heads_to_keep = kv_heads_to_keep.tolist() - kv_heads_to_remove = kv_heads_to_remove.tolist() - kv_head_ordering = [] - zero_out_mask = [] - for i_head in range(orig_num_kv_heads): - if i_head % reduction_factor == 0: - kv_head_ordering.append(kv_heads_to_keep.pop(0)) - zero_out_mask.append(False) - else: - kv_head_ordering.append(kv_heads_to_remove.pop(0)) - zero_out_mask.append(True) - - wq = wq[kv_head_ordering] - - ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. - ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. - ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. - ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. - wo = wo[:, kv_head_ordering] - wo[:, zero_out_mask] = 0.0 - - else: - raise ValueError(f"{gqa_init_mode=} not supported") - - wk = wk.reshape(-1, dim1) - wv = wv.reshape(-1, dim1) - wq = wq.reshape(-1, dim1) - wo = wo.reshape(dim1, -1) - return wq, wk, wv, wo - - -def _init_attention_biases( - gqa_init_mode, - layer_idx, - new_state_dict, - new_config: DeciLMConfig, - original_state_dict, - q_key, - k_key, - v_key, - o_key, - original_config, - is_original_mha, - head_size, - mlp_init_config, -): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" - ) - num_q_heads = new_config.num_attention_heads - n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group - orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group - num_kv_heads = num_q_heads // n_heads_in_group - orig_num_kv_heads = num_q_heads // orig_n_heads_in_group - - o_proj_bias = new_config.o_proj_bias - attention_bias = new_config.attention_bias - - # If no biases - if not (o_proj_bias or attention_bias): - return {} - - new_bias_sd = {} - bias_sd = {} - # new_w* are typically randomly initialized - if o_proj_bias: - new_bias_sd["o"] = new_state_dict[o_key] - bias_sd["o"] = original_state_dict[o_key] - if attention_bias: - for bias_key, key in zip("qkv", [q_key, k_key, v_key]): - new_bias_sd[bias_key] = new_state_dict[key] - bias_sd[bias_key] = original_state_dict[key] - - # maybe unsqueeze all tensors - for tensor in list(new_bias_sd.values()) + list(bias_sd.values()): - assert tensor.ndim == 1 - tensor.unsqueeze_(1) - - dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases - if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias: - bias_sd["k"] = torch.zeros( - new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device - ) - bias_sd["v"] = torch.zeros( - new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device - ) - elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias: - assert n_heads_in_group % orig_n_heads_in_group == 0, ( - f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" - ) - n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group - - bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1) - bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1) - - if gqa_init_mode == GQAInitMode.AverageKV: - bias_sd["k"] = bias_sd["k"].mean(dim=1) - bias_sd["v"] = bias_sd["v"].mean(dim=1) - else: - bias_sd["k"] = bias_sd["k"][:, 0] - bias_sd["v"] = bias_sd["v"][:, 0] - elif gqa_init_mode == GQAInitMode.CopyAsIs: - for key in bias_sd: - assert new_bias_sd[key].shape == bias_sd[key].shape, ( - f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" - ) - - elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias: - assert not is_original_mha, ( - "Degrouping can only be done on original models that are GQA themselves." - ) - n_groups = new_config.num_attention_heads // n_heads_in_group - orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group - assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" - n_repeats = n_groups // orig_n_groups - if n_repeats > 1: - print(f"Degrouping {orig_n_groups} into {n_groups}") - - def degroup_w(w): - w = w.view(orig_n_groups, head_size, dim1) - w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) - w = w.reshape(n_groups * head_size, dim1) - return w - - bias_sd["k"] = degroup_w(bias_sd["k"]) - bias_sd["v"] = degroup_w(bias_sd["v"]) - - elif gqa_init_mode == GQAInitMode.PruneKVHeads: - if o_proj_bias: - o_proj_module_name = o_key.rsplit(".", 1)[0] - else: - # Here we assume that the o_proj layer is called "o_proj" - o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj" - - kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) - kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) - kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] - kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] - - # view as KV groups - if attention_bias: - bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1) - bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1) - bias_sd["q"] = bias_sd["q"].view( - orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1 - ) - # Keep important KV heads and prune the others - bias_sd["k"] = bias_sd["k"][kv_heads_to_keep] - bias_sd["v"] = bias_sd["v"][kv_heads_to_keep] - if o_proj_bias: - bias_sd["o"] = bias_sd["o"].view( - dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size - ) - - reduction_factor = orig_num_kv_heads // num_kv_heads - - prune_via_duplication = False - if prune_via_duplication: - if attention_bias: - ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. - bias_sd["q"] = bias_sd["q"][kv_heads_to_keep] - bias_sd["q"] = torch.repeat_interleave( - bias_sd["q"], repeats=reduction_factor, dim=0 - ) - - if o_proj_bias: - ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. - ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. - bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep] - bias_sd["o"] = torch.repeat_interleave( - bias_sd["o"], repeats=reduction_factor, dim=1 - ) - bias_sd["o"] = bias_sd["o"] / reduction_factor - - else: # prune via zeroing out - ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. - ## We need to interleave them to keep the matching between queries and kv heads. - kv_heads_to_keep = kv_heads_to_keep.tolist() - kv_heads_to_remove = kv_heads_to_remove.tolist() - kv_head_ordering = [] - zero_out_mask = [] - for i_head in range(orig_num_kv_heads): - if i_head % reduction_factor == 0: - kv_head_ordering.append(kv_heads_to_keep.pop(0)) - zero_out_mask.append(False) - else: - kv_head_ordering.append(kv_heads_to_remove.pop(0)) - zero_out_mask.append(True) - - if attention_bias: - bias_sd["q"] = bias_sd["q"][kv_head_ordering] - - if o_proj_bias: - ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. - ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. - ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. - ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. - bias_sd["o"] = bias_sd["o"][:, kv_head_ordering] - bias_sd["o"][:, zero_out_mask] = 0.0 - - else: - raise ValueError(f"{gqa_init_mode=} not supported") - - if attention_bias: - for bias_key in "qkv": - bias_sd[bias_key] = bias_sd[bias_key].reshape(-1) - if o_proj_bias: - bias_sd["o"] = bias_sd["o"].reshape(-1) - return bias_sd - - def _init_linear_attn( parent_state_dict: dict[str, torch.Tensor], parent_config: DeciLMConfig, @@ -1226,13 +754,15 @@ def _init_linear_attn( v_key: str, o_key: str, ) -> torch.Tensor: - """Init a linear layer that operates like an attention layer that assigns score 1 to the current token + """ + Init a linear layer that operates like an attention layer that assigns score 1 to the current token and score 0 to all others: out = (Wo @ Wv) @ x """ n_embd = parent_config.hidden_size - head_size = parent_config.head_dim - n_heads_in_group = parent_config.block_configs[layer_idx].attention.n_heads_in_group - n_kv_heads = parent_config.num_attention_heads // n_heads_in_group + head_size = _get_head_dim(parent_config) + # Get num_kv_heads from config, compute n_heads_in_group + n_kv_heads = parent_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = parent_config.num_attention_heads // n_kv_heads wv = parent_state_dict[v_key] wv = wv.view(n_kv_heads, head_size, n_embd) @@ -1245,7 +775,9 @@ def _init_linear_attn( def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.Tensor: - """A linear layer that does (W_down @ W_up) @ x, ignoring W_gate.""" + """ + A linear layer that does (W_down @ W_up) @ x, ignoring W_gate. + """ if "linear_mlp.weight" in teacher_mlp_state_dict: # if the teacher itself is a linear layer return teacher_mlp_state_dict["linear_mlp.weight"] @@ -1314,9 +846,10 @@ def _parse_model_config_overrides( model_config_overrides_json: str | dict | Path | list[dict], n_layer: int, ) -> list[dict[str, Any]]: - """Example model_config_overrides_json: + """ + example model_config_overrides_dict: { - "attention": [{"n_heads_in_group": 2}], + "attention": [{"num_key_value_heads": 4}], "ffn": [{"intermediate_size": 14336}] } """ @@ -1362,18 +895,24 @@ def _apply_hidden_size_pruning( original_state_dict: dict[str, torch.Tensor], new_config: DeciLMConfig, original_config: DeciLMConfig, + descriptor, hidden_size_init_mode: HiddenSizeInitMode, - channel_importance_path: str | None = None, - owned_block_indexes: list[int] | None = None, + channel_importance_path: Optional[str] = None, + owned_block_indexes: Optional[list[int]] = None, ) -> dict[str, torch.Tensor]: - """Apply hidden size pruning to all layers that depend on hidden_size. + """ + Apply hidden size pruning to all layers that depend on hidden_size. This includes embeddings, layer norms, and any linear layers that haven't been handled yet. """ if isinstance(hidden_size_init_mode, str): hidden_size_init_mode = HiddenSizeInitMode(hidden_size_init_mode) - original_hidden_size = original_config.hidden_size - new_hidden_size = new_config.hidden_size + # Get language model config (for VL models this extracts the nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + original_hidden_size = original_lm_config.hidden_size + new_hidden_size = new_lm_config.hidden_size if hidden_size_init_mode == HiddenSizeInitMode.CopyAsIs: return out_state_dict @@ -1381,7 +920,7 @@ def _apply_hidden_size_pruning( # Load channel ranking if needed if hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: if channel_importance_path is not None: - with open(channel_importance_path) as f: + with open(channel_importance_path, "r") as f: channel_ranking = json.load(f)["channel_importance_ranking"] else: raise ValueError( @@ -1574,10 +1113,12 @@ def _prune_hidden_size_dimension( original_tensor: torch.Tensor, new_hidden_size: int, hidden_size_init_mode: HiddenSizeInitMode, - channel_ranking: list[int] | None = None, + channel_ranking: Optional[list[int]] = None, dim: int = -1, ) -> torch.Tensor: - """Prune a tensor along the specified dimension to match the new hidden size.""" + """ + Prune a tensor along the specified dimension to match the new hidden size. + """ original_size = original_tensor.shape[dim] if hidden_size_init_mode == HiddenSizeInitMode.Random: @@ -1627,3 +1168,14 @@ def _prune_hidden_size_dimension( else: raise ValueError(f"Unsupported hidden_size_init_mode: {hidden_size_init_mode}") + + +def _get_head_dim(config) -> int: + """Get head dimension from config in a model-agnostic way. + + Some models like Llama have `head_dim` as a direct attribute, while others + like Qwen2 don't. This helper computes it from hidden_size and num_attention_heads. + """ + if hasattr(config, "head_dim") and config.head_dim is not None: + return config.head_dim + return config.hidden_size // config.num_attention_heads diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index f52c12d26f..3c3b54830a 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -14,11 +14,13 @@ # limitations under the License. # mypy: ignore-errors -"""Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, +""" +Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, particularly for DeciLM models. """ import concurrent.futures +import dataclasses import fcntl import os import shutil @@ -31,9 +33,12 @@ import torch from safetensors.torch import save_file as safe_save_file +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from modelopt.torch.puzzletron.decilm import deci_lm_hf_code +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import maybe_cast_block_configs from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from modelopt.torch.puzzletron.tools.common import infer_weights_dtype @@ -69,7 +74,8 @@ def load_checkpoint( model_config_overrides: dict | None = None, ignore_unexpected_config_keys: bool = False, ) -> DeciLMForCausalLM: - """Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your + """ + Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your local repo code, not the code inside the checkpoint. """ from modelopt.torch.puzzletron.tools.checkpoint_utils import ( @@ -99,20 +105,54 @@ def load_checkpoint( return model +def force_cache_dynamic_modules( + config: PretrainedConfig, checkpoint_dir: Path | str, trust_remote_code: bool = False +): + has_remote_code = ( + hasattr(config, "auto_map") + and isinstance(config.auto_map, dict) + and "AutoConfig" in config.auto_map.keys() + ) + if has_remote_code and trust_remote_code: + for class_reference in config.auto_map.values(): + _ = get_class_from_dynamic_module(class_reference, checkpoint_dir) + + def load_model_config( checkpoint_dir: Path | str, model_config_overrides: Mapping | None = None, ignore_unexpected_config_keys: bool = False, -) -> DeciLMConfig: + trust_remote_code: bool = False, +): + """Load model configuration from a checkpoint directory. + + Args: + checkpoint_dir: Path to the checkpoint directory (e.g. containing config.json). + model_config_overrides: Optional mapping of config overrides. + ignore_unexpected_config_keys: If True, ignore unexpected config keys. + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + + Returns: + Loaded model configuration (PretrainedConfig). + """ if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) if model_config_overrides is None: model_config_overrides = {} - config, unused_kwargs = DeciLMConfig.from_pretrained( - checkpoint_dir, return_unused_kwargs=True, **model_config_overrides + config, unused_kwargs = AutoConfig.from_pretrained( + checkpoint_dir, + trust_remote_code=trust_remote_code, + return_unused_kwargs=True, + **model_config_overrides, ) + if hasattr(config, "block_configs"): + config.block_configs = maybe_cast_block_configs(config.block_configs) + + force_cache_dynamic_modules(config, checkpoint_dir, trust_remote_code=trust_remote_code) if not ignore_unexpected_config_keys: if unused_kwargs: @@ -121,73 +161,64 @@ def load_model_config( return config -def save_checkpoint(model: DeciLMForCausalLM, checkpoint_dir: Path | str) -> None: - _save_checkpoint(model.config, model.state_dict(), checkpoint_dir) +def save_checkpoint( + model: PreTrainedModel, + checkpoint_dir: Path | str, + descriptor: "ModelDescriptor", +) -> None: + _save_checkpoint(model.config, model.state_dict(), checkpoint_dir, descriptor) def _save_checkpoint( - model_config: DeciLMConfig, + model_config: PretrainedConfig, state_dict: dict[str, torch.Tensor], checkpoint_dir: Path | str, + descriptor: "ModelDescriptor", max_workers: int | None = None, # Now optional - will auto-calculate if None ) -> None: - mprint("=== Starting _save_checkpoint detailed profiling ===") - total_start_time = time.time() + from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) - # Phase 1: Create directory and save config - phase1_start_time = time.time() checkpoint_dir.mkdir(parents=True, exist_ok=True) - model_config.save_pretrained(checkpoint_dir) - phase1_time = time.time() - phase1_start_time - mprint(f"Phase 1 - Directory creation and config save: {phase1_time:.2f}s") - # Phase 2: Save subblocks (main model weights) with auto-calculated worker count - phase2_start_time = time.time() + # Phase 1: Save config + save_model_config(model_config, checkpoint_dir) + + # Phase 2: Build weight map using descriptor and write index + subblock_keys = descriptor.get_weight_groups( + layer_names=state_dict.keys(), + num_hidden_layers=model_config.num_hidden_layers, + ) + + weight_map = {} + for subblock, layer_keys in subblock_keys.items(): + weight_map_entries = { + key: f"subblocks_safetensors/{subblock}.safetensors" for key in layer_keys + } + weight_map.update(weight_map_entries) + + # Handle tie_word_embeddings - remove from state_dict and weight_map BEFORE writing index + output_emb_weight_name = f"{descriptor.output_embedding_name()}.weight" + if getattr(model_config, "tie_word_embeddings", False) and output_emb_weight_name in state_dict: + state_dict = {k: v for k, v in state_dict.items() if k != output_emb_weight_name} + weight_map = {k: v for k, v in weight_map.items() if k != output_emb_weight_name} + + # Write index (now without tied embedding) + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME + index_json = json_dumps(index) + _write_file_process_safe(index_json, index_path) + + # Phase 3: Save subblocks save_subblocks( state_dict, checkpoint_dir, + weight_map=weight_map, multi_threaded=True, - max_workers=max_workers, # Will auto-calculate if None + max_workers=max_workers, ) - phase2_time = time.time() - phase2_start_time - mprint(f"Phase 2 - Save subblocks (model weights): {phase2_time:.2f}s") - - # Phase 3: Save safetensors index - phase3_start_time = time.time() - save_safetensors_index(model_config, checkpoint_dir) - phase3_time = time.time() - phase3_start_time - mprint(f"Phase 3 - Save safetensors index: {phase3_time:.2f}s") - - # Phase 4: Copy HF code - phase4_start_time = time.time() - copy_deci_lm_hf_code(checkpoint_dir) - phase4_time = time.time() - phase4_start_time - mprint(f"Phase 4 - Copy HF code: {phase4_time:.2f}s") - - total_time = time.time() - total_start_time - mprint(f"=== _save_checkpoint completed in {total_time:.2f}s ===") - mprint( - f"Breakdown: Config {phase1_time:.1f}s + Subblocks {phase2_time:.1f}s + " - f"Index {phase3_time:.1f}s + HF code {phase4_time:.1f}s" - ) - mprint( - f"Save percentage breakdown: Config {phase1_time / total_time * 100:.1f}% + " - f"Subblocks {phase2_time / total_time * 100:.1f}% + " - f"Index {phase3_time / total_time * 100:.1f}% + " - f"HF code {phase4_time / total_time * 100:.1f}%" - ) - - # Performance metrics - if phase2_time > 0: - subblocks_percentage = phase2_time / total_time * 100 - actual_workers = max_workers if max_workers else "auto" - mprint( - f"I/O optimization: Subblocks were {subblocks_percentage:.1f}% of total save time " - f"(max_workers={actual_workers})" - ) def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: @@ -210,6 +241,7 @@ def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: def save_subblocks( state_dict: dict[str, torch.Tensor], checkpoint_dir: Path | str, + weight_map: dict[str, str] | None = None, multi_threaded: bool = True, max_workers: int | None = None, # Now optional - will auto-calculate if None ) -> None: @@ -219,14 +251,15 @@ def save_subblocks( if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) - # Step 1: Build weight map + # Step 1: Build weight map (use provided or build from state_dict) weight_map_start_time = time.time() - weight_map = _build_safetensors_weight_map( - state_dict=state_dict, - non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, - module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, - layers_module_name=LAYERS_MODULE_NAME, - ) + if weight_map is None: + weight_map = _build_safetensors_weight_map( + state_dict=state_dict, + non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, + module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, + layers_module_name=LAYERS_MODULE_NAME, + ) weight_name_to_filename = {k: checkpoint_dir / v for k, v in weight_map.items()} weight_map_time = time.time() - weight_map_start_time mprint(f" Step 1 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") @@ -323,6 +356,7 @@ def save_safetensors_index( model_config: DeciLMConfig, checkpoint_dir: Path | str, ) -> None: + """Save safetensors index for DeciLM models (legacy function).""" mprint("=== Starting save_safetensors_index profiling ===") index_start_time = time.time() @@ -372,7 +406,8 @@ def _write_file_process_safe( path: Path | str, write_fn: Callable[[Any, BinaryIO], None] = _write_text, ) -> None: - """Write a file in a multi-process safe way. + """ + Write a file in a multi-process safe way. If another process tries to write the same file using this method, the current process "gives up" and assumes that the matter is being taken care of by another process. @@ -435,13 +470,19 @@ def _build_safetensors_weight_map( return weight_map -# Not really needed -def save_model_config(model_config: DeciLMConfig, checkpoint_dir: Path | str) -> None: +def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: + if hasattr(model_config, "block_configs"): + model_config.block_configs = [ + dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf + for conf in model_config.block_configs + ] model_config.save_pretrained(checkpoint_dir) def copy_deci_lm_hf_code(output_dir: Path | str) -> None: - """Copy the deci_lm_hf_code directory to the output directory.""" + """ + Copy the deci_lm_hf_code directory to the output directory. + """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) code_dir = Path(deci_lm_hf_code.__file__).parent diff --git a/modelopt/torch/puzzletron/utils/dummy_modules.py b/modelopt/torch/puzzletron/utils/dummy_modules.py new file mode 100644 index 0000000000..c9eaa2bc6c --- /dev/null +++ b/modelopt/torch/puzzletron/utils/dummy_modules.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from typing_extensions import override + + +class DummyModule(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) + + @staticmethod + def load_state_dict_post_hook( + module: torch.nn.Module, + incompatible_keys: torch.nn.modules.module._IncompatibleKeys, + ) -> None: + incompatible_keys.missing_keys.clear() + incompatible_keys.unexpected_keys.clear() + + +class DummyBlock(DummyModule): + def __init__(self, block_index: int): + super().__init__() + self.block_index = block_index + + @override + def forward( + self, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, None]: + return x + + +class DummyWTE(DummyModule): + def __init__(self, hidden_size: int, dtype: Optional[torch.dtype] = None): + super().__init__() + self.n_embd = hidden_size + self.dtype = dtype + + @override + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + B, T = input_ids.shape + result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device) + return result + + +class DummyLMHead(DummyModule): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.vocab_size = config.vocab_size + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.shape + result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device) + return result diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index 6c9feecd0d..07d1565f42 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -19,14 +19,24 @@ import torch from datasets import Dataset, DatasetDict -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers +# Path to HF configs relative to this file +# HF configs are in tests/gpu/torch/puzzletron/resources/hf_configs +HF_CONFIGS_DIR = ( + Path(__file__).parent.parent.parent.parent / "gpu/torch/puzzletron/resources/hf_configs" +) + def setup_test_model_and_data( - project_root_path: Path, tmp_path: Path, rank: int + project_root_path: Path, + tmp_path: Path, + rank: int, + hf_config_name: str, + hybrid_override_pattern: str | None = None, ) -> tuple[Path, Path, Path]: """ Setup the test model and data for the puzzletron NAS search. @@ -35,10 +45,12 @@ def setup_test_model_and_data( project_root_path (Path): the root path of the project tmp_path (Path): the temporary path to use for the test rank (int): the rank of the process + hf_config_name (str): Name of the HF config directory (e.g., "llama_3_1_8b_instruct") + hybrid_override_pattern (str): For NemotronH models, the layer type pattern Returns: tuple[Path, Path, Path]: - the puzzle_dir, llama_checkpoint_path, dataset_path + the puzzle_dir, hf_checkpoint_path, dataset_path """ # Register Hydra custom resolvers (needed for config resolution) @@ -46,8 +58,8 @@ def setup_test_model_and_data( # The inputs for the nas.convert() step. # - puzzle_dir = tmp_path - llama_checkpoint_path = puzzle_dir / "input_model/llama" + puzzle_dir = tmp_path / hf_config_name + hf_checkpoint_path = puzzle_dir / f"hf_models/{hf_config_name}" dataset_path = puzzle_dir / "dummy_dataset" if rank == 0: @@ -55,74 +67,133 @@ def setup_test_model_and_data( setup_puzzle_dir(puzzle_dir) save_dummy_dataset(dataset_path) - # Create a small Llama model + # Create a small HF model tokenizer = create_tokenizer(project_root_path) - create_and_save_small_llama_model( - llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer + create_and_save_small_hf_model( + output_path=str(hf_checkpoint_path), + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + hf_config_name=hf_config_name, + hybrid_override_pattern=hybrid_override_pattern, ) dist.barrier() return ( puzzle_dir, - llama_checkpoint_path, + hf_checkpoint_path, dataset_path, ) -def create_and_save_small_llama_model( - output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase +def create_and_save_small_hf_model( + output_path: str, + vocab_size: int, + tokenizer: PreTrainedTokenizerBase, + hf_config_name: str, + hybrid_override_pattern: str | None = None, ): """ - Create and save a small Llama model for testing the conversion pipeline. - This mimics having a real Llama checkpoint that needs to be converted. + Create and save a small HuggingFace model for testing the conversion pipeline. + Uses real HuggingFace config to preserve model-specific settings (like tie_word_embeddings), + but shrinks size parameters for fast testing. + + Args: + output_path: Where to save the model + vocab_size: Vocabulary size (should match tokenizer) + tokenizer: Tokenizer to save alongside the model + hf_config_name: Name of the config directory under resources/hf_configs/ + e.g., "llama_3_1_8b_instruct", "llama_3_2_3b_instruct", or "qwen2_5_7b_instruct" + hybrid_override_pattern: For NemotronH models, the layer type pattern (e.g., "*-" for Attention+MLP, + "M-" for Mamba+MLP). Must match num_hidden_layers. None for non-NemotronH models. """ os.makedirs(output_path, exist_ok=True) - # Create a minimal Llama config (small for testing) + # Load real HuggingFace config (preserves tie_word_embeddings, rope_scaling, etc.) + config_path = HF_CONFIGS_DIR / hf_config_name + config = AutoConfig.from_pretrained(config_path, local_files_only=True, trust_remote_code=True) + + # Override size-related params to make it small for testing # Note: intermediate_size must be divisible by 256 per DeciLM config requirements # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility - llama_config = LlamaConfig( - vocab_size=vocab_size, - hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations) - intermediate_size=512, # Must be divisible by 256 - num_hidden_layers=2, - num_attention_heads=32, # Matches original test - num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4) - max_position_embeddings=512, - rms_norm_eps=1e-5, - rope_theta=10000.0, - attention_bias=False, - hidden_act="silu", - tie_word_embeddings=False, - ) - # Create and save the Llama model - model = LlamaForCausalLM(llama_config) + # VL models have nested configs (text_config, vision_config) + if hf_config_name == "qwen3-vl-30b-a3b-instruct": + config.text_config.vocab_size = vocab_size + config.text_config.hidden_size = 256 + config.text_config.intermediate_size = 512 + config.text_config.num_hidden_layers = 2 + config.text_config.num_attention_heads = 32 + config.text_config.num_key_value_heads = 8 + config.text_config.num_experts = 16 # Reduce from 128 + config.text_config.moe_intermediate_size = 256 + config.text_config.max_position_embeddings = 512 + config.vision_config.depth = 2 # Reduce from 27 + config.vision_config.hidden_size = 256 + config.vision_config.intermediate_size = 512 + config.vision_config.out_hidden_size = 256 + # TODO: this is hack, redesign converter to not read config.num_hidden_layers directly. + # set top-level num_hidden_layers for converter compatibility + config.num_hidden_layers = config.text_config.num_hidden_layers + else: + # Regular models have flat config + config.vocab_size = vocab_size + config.hidden_size = 256 + config.intermediate_size = 512 + config.num_hidden_layers = 2 + config.num_attention_heads = 32 + config.num_key_value_heads = 8 + config.max_position_embeddings = 512 + + # Fix layer_types to match num_hidden_layers (newer transformers validates this) + if hasattr(config, "layer_types") and config.layer_types is not None: + config.layer_types = config.layer_types[:2] + + # Fix rope_scaling to be consistent with max_position_embeddings + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + config.rope_scaling["original_max_position_embeddings"] = 256 + + # NemotronH requires hybrid_override_pattern to match num_hidden_layers + if hasattr(config, "hybrid_override_pattern") and hybrid_override_pattern is not None: + config.hybrid_override_pattern = hybrid_override_pattern + + # Set seed for reproducible weight initialization + torch.manual_seed(42) + + # Create and save the model + # TODO: Consider using AutoModel.from_config instead. + if hf_config_name == "qwen3-vl-30b-a3b-instruct": + from transformers import Qwen3VLMoeForConditionalGeneration + + model = Qwen3VLMoeForConditionalGeneration._from_config(config) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.to(dtype=torch.bfloat16).save_pretrained(output_path) # Save tokenizer tokenizer.save_pretrained(output_path) # Save config - llama_config.save_pretrained(output_path) + config.save_pretrained(output_path) def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: """ - Create a tokenizer for the Llama model. + Create a tokenizer for the model. """ - tokenizer_path = project_root_path / "tests/_test_utils/torch/puzzletron/resources/tokenizer" + tokenizer_path = project_root_path / "tests/gpu/torch/puzzletron/resources/tokenizer" tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) return tokenizer -def setup_puzzle_dir(puzzle_dir: str): +def setup_puzzle_dir(puzzle_dir: str | Path): """ Setup puzzle directory by removing existing directory and creating a new one. """ - if Path(puzzle_dir).exists(): + puzzle_dir = Path(puzzle_dir) + if puzzle_dir.exists(): shutil.rmtree(puzzle_dir) - Path(puzzle_dir).mkdir(parents=True, exist_ok=True) + puzzle_dir.mkdir(parents=True, exist_ok=True) def save_dummy_dataset(dataset_path: Path | str): diff --git a/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py deleted file mode 100644 index 4b1ea0b414..0000000000 --- a/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py +++ /dev/null @@ -1,50 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from pathlib import Path - -from _test_utils.torch.puzzletron.utils import create_and_save_small_llama_model, create_tokenizer - -from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( - convert_llama3_to_decilm, -) - - -def test_convert_llama3_config_to_decilm_config(project_root_path: Path, tmp_path: Path): - tokenizer = create_tokenizer(project_root_path) - llama_checkpoint_path = tmp_path / "llama_checkpoint" - create_and_save_small_llama_model( - llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer - ) - - # Convert the Llama model to a DeciLM model - decilm_checkpoint_path = tmp_path / "decilm_checkpoint" - convert_llama3_to_decilm( - input_dir=llama_checkpoint_path, - output_dir=decilm_checkpoint_path, - ) - - # Assert that the converted config has the correct number of block_configs - config_path = decilm_checkpoint_path / "config.json" - assert config_path.exists(), f"Config file not found at {config_path}" - - with open(config_path) as f: - decilm_config = json.load(f) - - # Verify block_configs exists and has the correct length - assert "block_configs" in decilm_config, "block_configs not found in converted config" - actual_num_block_configs = len(decilm_config["block_configs"]) - assert actual_num_block_configs == 2 diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index c409da28be..e2373676d2 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -18,6 +18,7 @@ from functools import partial from pathlib import Path +import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data @@ -27,6 +28,7 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +@pytest.mark.skip(reason="Temporarily disabled") def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -41,10 +43,12 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" + hydra_config_dir = ( + project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" + ) + hydra_config_name = "llama_3_1_8b_instruct" # # Run the mnt.convert() step @@ -83,6 +87,7 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.cleanup() +@pytest.mark.skip(reason="Temporarily disabled") def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -97,10 +102,12 @@ def _test_nas_convert_attn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" + ) + hydra_config_dir = ( + project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-attn-pruning" + hydra_config_name = "llama_3_1_8b_instruct-attn-pruning" # # Run the mnt.convert() step diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index a1258c1d0b..e39f1e1cbc 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -17,6 +17,7 @@ from functools import partial from pathlib import Path +import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data @@ -26,6 +27,7 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +@pytest.mark.skip(reason="Temporarily disabled") def test_nas_search(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -40,10 +42,12 @@ def _test_nas_search_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" + hydra_config_dir = ( + project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" + ) + hydra_config_name = "llama_3_1_8b_instruct" # # Run the mnt.convert() step diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml new file mode 100644 index 0000000000..02c73aca69 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml @@ -0,0 +1,107 @@ +defaults: + - pruning: attn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml new file mode 100644 index 0000000000..65ca64ef4e --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml @@ -0,0 +1,107 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..cad6fcf3ee --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..b24ea1b7cc --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 100 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml new file mode 100644 index 0000000000..9dabef7413 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: conversation +val_dataset_name: train +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json b/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json new file mode 100644 index 0000000000..0bb6fd75b3 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json @@ -0,0 +1,38 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": [ + 128001, + 128008, + 128009 + ], + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.42.3", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json b/tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..02ee80b619 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json @@ -0,0 +1,16 @@ +{ + "bos_token": { + "content": "<|begin_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "<|eot_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json new file mode 100644 index 0000000000..83592e2494 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json @@ -0,0 +1,212 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + }, + "behavior": "Isolated", + "invert": false + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false + } + ] + }, + "post_processor": { + "type": "Sequence", + "processors": [ + { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 1 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "<|begin_of_text|>": { + "id": "<|begin_of_text|>", + "ids": [ + 100 + ], + "tokens": [ + "<|begin_of_text|>" + ] + } + } + } + ] + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "!": 0, + "\"": 1, + "#": 2, + "$": 3, + "%": 4, + "&": 5, + "'": 6, + "(": 7, + ")": 8, + "*": 9, + "+": 10, + ",": 11, + "-": 12, + ".": 13, + "/": 14, + "0": 15, + "1": 16, + "2": 17, + "3": 18, + "4": 19, + "5": 20, + "6": 21, + "7": 22, + "8": 23, + "9": 24, + ":": 25, + ";": 26, + "<": 27, + "=": 28, + ">": 29, + "?": 30, + "@": 31, + "A": 32, + "B": 33, + "C": 34, + "D": 35, + "E": 36, + "F": 37, + "G": 38, + "H": 39, + "I": 40, + "J": 41, + "K": 42, + "L": 43, + "M": 44, + "N": 45, + "O": 46, + "P": 47, + "Q": 48, + "R": 49, + "S": 50, + "T": 51, + "U": 52, + "V": 53, + "W": 54, + "X": 55, + "Y": 56, + "Z": 57, + "[": 58, + "\\": 59, + "]": 60, + "^": 61, + "_": 62, + "`": 63, + "a": 64, + "b": 65, + "c": 66, + "d": 67, + "e": 68, + "f": 69, + "g": 70, + "h": 71, + "i": 72, + "j": 73, + "k": 74, + "l": 75, + "m": 76, + "n": 77, + "o": 78, + "p": 79, + "q": 80, + "r": 81, + "s": 82, + "t": 83, + "u": 84, + "v": 85, + "w": 86, + "x": 87, + "y": 88, + "z": 89, + "{": 90, + "|": 91, + "}": 92, + "~": 93, + "¡": 94, + "¢": 95, + "£": 96, + "¤": 97, + "¥": 98, + "¦": 99, + "<|begin_of_text|>": 100, + "<|eot_id|>": 101 + }, + "merges": [] + } +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..754d9e8db5 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json @@ -0,0 +1,13 @@ +{ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", + "clean_up_tokenization_spaces": true, + "eos_token": "<|eot_id|>", + "extra_special_tokens": {}, + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py b/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py new file mode 100644 index 0000000000..aedcae4ab2 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script was used to truncate the tokenizer.json file from Llama 3.1 8B model +to keep only the top 100 most common tokens. +""" + +import json + +# Path to your original and new tokenizer.json +in_path = "./tokenizer.json" +out_path = "./tokenizer_truncated.json" + +# How many top tokens to keep +NUM_TO_KEEP = 100 + +with open(in_path, encoding="utf-8") as f: + tokenizer_data = json.load(f) + +# Get and sort the original vocab by index (frequency proxy) +orig_vocab = tokenizer_data["model"]["vocab"] + +# Sort tokens by their original index (lowest index = assumed most common/important) +sorted_tokens = sorted(orig_vocab.items(), key=lambda item: item[1]) + +# Keep the top N tokens +tokens_to_keep = [tok for tok, idx in sorted_tokens[:NUM_TO_KEEP]] + +# Re-index the selected tokens: 0..N-1 +small_vocab = {tok: i for i, tok in enumerate(tokens_to_keep)} +tokenizer_data["model"]["vocab"] = small_vocab + +# Update vocab size +if "vocab_size" in tokenizer_data["model"]: + tokenizer_data["model"]["vocab_size"] = len(small_vocab) + +# Optionally remove merges if present and unneeded (mostly for BPE/WordPiece) +if "merges" in tokenizer_data["model"]: + tokenizer_data["model"]["merges"] = [] + +# Remove added_tokens if not needed +if "added_tokens" in tokenizer_data: + tokenizer_data["added_tokens"] = [] + +# Write out the truncated tokenizer.json +with open(out_path, "w", encoding="utf-8") as f: + json.dump(tokenizer_data, f, indent=2, ensure_ascii=False) + +print(f"Truncated tokenizer saved to: {out_path}") diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index faf72f7495..3a5d9a8cee 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -13,19 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from datetime import timedelta from functools import partial from pathlib import Path +import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron import puzzletron -from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( - convert_llama3_to_decilm, -) +from modelopt.torch.puzzletron.anymodel import convert_model # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) # using a one-click command. @@ -33,91 +32,279 @@ # Note: Bypass is disabled now in the test. -def test_puzzletron(project_root_path: Path, tmp_path: Path): +@pytest.mark.parametrize( + ( + "hf_config_name", + "converter", + "hydra_config_subdir", + "hybrid_override_pattern", + "has_moe_layers", + ), + [ + ("llama_3_1_8b_instruct", "llama", "llama_3_1_8b_instruct", None, False), + # ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), + # ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), + # ( + # "mistral-small-24b-instruct-2501", + # "mistral_small", + # "mistral-small-24b-instruct-2501", + # None, + # False, + # ), + # ("qwen3-8b", "qwen3", "qwen3-8b", None, False), + # ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), + # ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), + # ( + # "nemotron-3-nano-30b-a3b-base-bf16", + # "nemotron_h", + # "nemotron-3-nano-30b-a3b-base-bf16", + # "*E", + # True, + # ), + # ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), + ], +) +def test_puzzletron( + project_root_path: Path, + tmp_path: Path, + hf_config_name: str, + converter: str, + hydra_config_subdir: str, + hybrid_override_pattern: str, + has_moe_layers: bool, +): spawn_multiprocess_job( - size=min(torch.cuda.device_count(), 2), # assertions configured for atmost 2 GPUs - job=partial(_test_puzzletron_multiprocess_job, project_root_path, tmp_path), + size=torch.cuda.device_count(), + job=partial( + _test_puzzletron_multiprocess_job, + project_root_path, + tmp_path, + hf_config_name, + converter, + hydra_config_subdir, + hybrid_override_pattern, + has_moe_layers, + ), backend="nccl", ) def _test_puzzletron_multiprocess_job( - project_root_path: Path, tmp_path: Path, rank: int, size: int + project_root_path: Path, + tmp_path: Path, + hf_config_name: str, + converter: str, + hydra_config_subdir: str, + hybrid_override_pattern: str, + has_moe_layers: bool, + rank: int, + size: int, ): dist.setup(timeout=timedelta(10)) + # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, hf_config_name, hybrid_override_pattern + ) + hydra_config_dir = ( # noqa: F841 + project_root_path / f"tests/gpu/torch/puzzletron/resources/configs/{hydra_config_subdir}" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" - # Convert the Llama model to DeciLM model. + # Convert the model using AnyModel converter. if rank == 0: - convert_llama3_to_decilm( - input_dir=llama_checkpoint_path, - output_dir=puzzle_dir / "ckpts/teacher", + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=converter, ) dist.barrier() - # Compress the model using a one-click approach - puzzletron.puzzletron( - str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path) - ) + # TODO commented for the duration of merging process from dkorzekwa/any_model to feature/puzzletron + # # Compress the model using a one-click approach + # puzzletron.puzzletron( + # str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) + # ) - # - # Check assertions - # - # assertions for the score_pruning_activations step 1 - _assert_score_pruning_activations(puzzle_dir) - if rank == 0: - # assertions for the pruning_ckpts step 2 - assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + # # + # # Check assertions + # # + # if rank == 0: + # if has_moe_layers: + # # assertions for the score_pruning_activations step 1 (MoE models only) + # rank_filepath = ( + # f"pruning/pruning_scores/expert_removal/10samples_diverse_mini/rank_{rank}.pth" + # ) + # assert (puzzle_dir / rank_filepath).is_file(), f"Expected {rank_filepath} to exist" - # assertions for the build_library_and_stats step 4 + # # assertions for the pruning_ckpts step 2 + # assert (puzzle_dir / "ckpts/num_experts_8").exists() - assert (puzzle_dir / "replacement_library.json").is_file() - assert (puzzle_dir / "subblock_stats.json").is_file() + # # assertions for the mip_and_realize_models step 6 + # # Find the MIP solution directory dynamically (e.g., stats_num_local_experts_*) + # mip_solutions_dir = puzzle_dir / "mip/puzzle_solutions" + # solution_dirs = [ + # d + # for d in mip_solutions_dir.iterdir() + # if d.is_dir() and d.name.startswith("stats_num_local_experts_") + # ] + # assert len(solution_dirs) == 1, ( + # f"Expected exactly one stats_num_local_experts_* directory, found: {[d.name for d in solution_dirs]}" + # ) + # solution_dir = solution_dirs[0] - # assertions for the scoring step 5 - solution_0_filepath = ( - puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" - ) + # solution_0_ckpt_config_path = ( + # solution_dir / "solutions--checkpoints/solution_0/config.json" + # ) + # assert solution_0_ckpt_config_path.exists() + # assert (solution_dir / "solutions.json").exists() - assert solution_0_filepath.exists() + # # Validate lm_loss + # _assert_lm_loss(puzzle_dir, hf_config_name) + # else: + # # assertions for the score_pruning_activations step 1 (FFN pruning) + # _assert_score_pruning_activations(puzzle_dir, hf_config_name) - # assertions for the mip_and_realize_models step 6 - solution_0_ckpt_config_path = ( - puzzle_dir - / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" - ) + # # assertions for the pruning_ckpts step 2 + # assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + # # assertions for the mip_and_realize_models step 6 + # _assert_mip_solutions(puzzle_dir, hf_config_name) - assert solution_0_ckpt_config_path.exists() - assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists() + # # assertions for the build_library_and_stats step 4 + # assert (puzzle_dir / "replacement_library.json").is_file() + # assert (puzzle_dir / "subblock_stats.json").is_file() + + # # assertions for the scoring step 5 + # solution_0_filepath = ( + # puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + # ) + # assert solution_0_filepath.exists() dist.cleanup() + print( + f"PYTEST SUMMARY: test_puzzletron({hf_config_name}) test has finished successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +# Expected pruning activation values per model +# Each model has a list of (score, channels) tuples for each FFN layer +EXPECTED_PRUNING_VALUES = { + "llama_3_1_8b_instruct": [ + {"score": 73, "channels": 95}, + {"score": 440, "channels": 174}, + ], + "llama_3_2_3b_instruct": [ + {"score": 79, "channels": 95}, + {"score": 428, "channels": 174}, + ], + "qwen2_5_7b_instruct": [ + {"score": 96, "channels": 433}, + {"score": 485, "channels": 105}, + ], + # Mistral Small 24B + "mistral-small-24b-instruct-2501": [ + {"score": 73, "channels": 95}, + {"score": 431, "channels": 174}, + ], + # Qwen3 8B + "qwen3-8b": [ + {"score": 208, "channels": 51}, + {"score": 475, "channels": 266}, + ], + # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer) + "nemotron-nano-12b-v2": [ + {"score": 70, "channels": 509}, + ], + # Note: nemotron-3-nano-30b-a3b-base-bf16 uses MoE expert pruning, not FFN pruning + # so it doesn't have EXPECTED_PRUNING_VALUES +} + -def _assert_score_pruning_activations(puzzle_dir: Path): +# Expected lm_loss values per model +EXPECTED_LM_LOSS = { + "llama_3_1_8b_instruct": 4.706878662109375, + "llama_3_2_3b_instruct": 4.816886901855469, + "qwen2_5_7b_instruct": 4.778186798095703, + "nemotron-nano-12b-v2": 4.79390811920166, + "mistral-small-24b-instruct-2501": 4.709150314331055, + "qwen3-8b": 4.733874320983887, + "gpt-oss-20b": 4.689250946044922, + "nemotron-3-nano-30b-a3b-base-bf16": 4.741103172302246, + "qwen3-vl-30b-a3b-instruct": 4.65625, +} + + +def _assert_score_pruning_activations(puzzle_dir: Path, hf_config_name: str): """Assertions for the score_pruning_activations step 1.""" rank = dist.rank() - size = dist.size() rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" assert (puzzle_dir / rank_filepath).is_file() pruning_scores = torch.load(puzzle_dir / rank_filepath) layer_names = list(pruning_scores.keys()) - assert len(layer_names) == 2 // size - - if size == 1 or rank == 0: - # Check specific values for layer 0 - layer_0 = pruning_scores[layer_names[0]] - assert layer_0["score"][0].item() == 371 - assert layer_0["channels_importance_ascending"][0].item() == 140 - - if size == 1 or rank == 1: - # Check specific values for layer 1 - layer_1 = pruning_scores[layer_names[1 if size == 1 else 0]] - assert layer_1["score"][0].item() == 269 - assert layer_1["channels_importance_ascending"][0].item() == 366 + expected = EXPECTED_PRUNING_VALUES[hf_config_name] + size = dist.size() + + if expected is not None: + # In multi-GPU: layers are distributed across ranks + # Each rank processes len(expected) // size layers + expected_layers_per_rank = len(expected) // size + assert len(layer_names) == expected_layers_per_rank, ( + f"Expected {expected_layers_per_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}" + ) + # Check each layer's values + for i, layer_name in enumerate(layer_names): + layer_data = pruning_scores[layer_name] + # Calculate global layer index from rank and local index + global_idx = rank * expected_layers_per_rank + i + assert layer_data["score"][0].item() == expected[global_idx]["score"] + assert ( + layer_data["channels_importance_ascending"][0].item() + == expected[global_idx]["channels"] + ) + else: + # Print values for new models - update EXPECTED_PRUNING_VALUES with these + print(f"\n=== PRUNING VALUES for {hf_config_name} (num_layers={len(layer_names)}) ===") + print(f'"{hf_config_name}": [') + for layer_name in layer_names: + layer_data = pruning_scores[layer_name] + score = layer_data["score"][0].item() + channels = layer_data["channels_importance_ascending"][0].item() + print(f' {{"score": {score}, "channels": {channels}}},') + print("],") + print("===") + + +def _assert_lm_loss(puzzle_dir: Path, hf_config_name: str): + """Validate lm_loss for a model solution.""" + solution_0_path = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + with open(solution_0_path) as f: + validation = json.load(f) + + actual_lm_loss = validation["lm_loss"]["avg"] + expected_lm_loss = EXPECTED_LM_LOSS.get(hf_config_name) + if expected_lm_loss is not None: + assert abs(actual_lm_loss - expected_lm_loss) < 0.01, ( + f"lm_loss mismatch: expected {expected_lm_loss}, got {actual_lm_loss}" + ) + else: + # Print value for new models - update EXPECTED_LM_LOSS with this + print(f"\n=== LM_LOSS for {hf_config_name} ===") + print(f'"{hf_config_name}": {actual_lm_loss},') + print("===") + + +def _assert_mip_solutions(puzzle_dir: Path, hf_config_name: str): + """Assertions for the mip_and_realize_models step.""" + mip_dir = puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB" + + assert (mip_dir / "solutions.json").exists() + assert (mip_dir / "solutions--checkpoints/solution_0/config.json").exists() + + # Validate lm_loss + _assert_lm_loss(puzzle_dir, hf_config_name) From 4dc99325c3dbf93f34a8286f159bd2979b2517a2 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 12 Mar 2026 10:36:08 +0100 Subject: [PATCH 35/62] Draft: anymodel activation scoring (#989) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Merging `anymodel acivation scoring` into `anymodel core` - this MR is only for reviewing. Ultimately `anymodel acivation scoring` should be merged into `feature/puzzletron` once `anymodel core` is merged there. ## Summary by CodeRabbit * **New Features** * Descriptor-driven model handling for broader architecture support * Re-enabled one‑click compression flow in tests * Device selection added for tensor creation * **Bug Fixes** * More robust checkpoint loading, sharding, and post‑load wiring * Improved forward-pass error logging and memory cleanup * **Refactor** * Simplified model validation API (pipeline toggle removed) * JSON encoder now serializes class and unknown object types more gracefully --------- Signed-off-by: Daniel Korzekwa --- .../activation_hooks/utils.py | 121 ++++------- .../score_pruning_activations.py | 2 +- modelopt/torch/puzzletron/puzzletron.py | 26 ++- modelopt/torch/puzzletron/sewing_kit/utils.py | 3 +- .../torch/puzzletron/tools/robust_json.py | 5 + .../tools/sharded_checkpoint_utils.py | 205 +++++++++++++----- .../torch/puzzletron/tools/validate_model.py | 189 +++++++--------- .../utils/validate_runtime_pipeline.py | 94 ++++++-- tests/gpu/torch/puzzletron/test_puzzletron.py | 11 +- 9 files changed, 384 insertions(+), 272 deletions(-) diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py index ab7eed2ac3..1b1485c713 100644 --- a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -15,84 +15,57 @@ # mypy: ignore-errors """Provides a function to register activation hooks for a model. -Activation hooks are used to compute activation scores for pruning. -""" +Activation hooks are used to compute activation scores for pruning.""" -import re +from typing import Type -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( - ForwardHook, - IndependentChannelContributionHook, - IndependentKvHeadContributionHook, - IterativeChannelContributionHook, - LayerNormContributionHook, -) -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook as ActivationsHook +from modelopt.torch.puzzletron.tools.logger import aprint def register_activation_hooks( - model: DeciLMForCausalLM, activation_hooks_kwargs: dict -) -> tuple[dict[str, ForwardHook], type[ForwardHook]]: - hook_class_map = { - "mlp.down_proj": { - "independent": IndependentChannelContributionHook, - "iterative": IterativeChannelContributionHook, - }, - "self_attn.o_proj": { - "independent_kv_head_contribution": IndependentKvHeadContributionHook, - }, - r"regex:experts\.\d+\.down_proj$": { # For MoE - "independent": IndependentChannelContributionHook, - }, - # TODO: maybe this is too generic, and we should have it specifically for - # input_layernorm and post_attention_layernorm; now it might select qk_norms - "layernorm": { - "layer_norm_contribution": LayerNormContributionHook, - }, - } - - activation_hooks = {} - target_layer = activation_hooks_kwargs.get("target_layer", "mlp.c_proj") - - if target_layer.startswith("regex:"): - target_layer_regex = target_layer[len("regex:") :] - pattern = re.compile(target_layer_regex) - - def match_predicate(module_name, module): - return pattern.search(module_name) - else: - - def match_predicate(module_name, module): - return module_name.endswith(target_layer) - - target_layer_hooks_map = hook_class_map.get(target_layer) - if target_layer_hooks_map is None: - raise ValueError(f"no hook classes found for: {target_layer}") - - hook_class = target_layer_hooks_map.get(activation_hooks_kwargs["method"]) - if hook_class is None: - raise ValueError(f"Unknown hook class: {hook_class}") - - if target_layer == "block": - pattern = re.compile(r"^transformer\.h\.\d+$") - - def match_predicate(module_name, module): - return pattern.match(module_name) - + model, + activation_hooks_kwargs: dict, + pruning_mixin, + hook_class: Type[ActivationsHook], +) -> dict[str, ActivationsHook]: + """Register activation hooks using the pruning mixin approach. + + Args: + model: The model to register hooks on. + activation_hooks_kwargs: Keyword arguments passed to hook constructors. + pruning_mixin: The pruning mixin that defines which modules to hook. + hook_class: The hook class to instantiate for each module. + + Returns: + Dictionary mapping module names to hook instances. + """ activation_hooks_kwargs["model"] = model - for module_name, module in model.named_modules(): - if match_predicate(module_name, module): - block_config = None - if block_idx_match := re.search(r"\.(\d+)\.", module_name): - block_idx = int(block_idx_match.group(1)) - block_config = model.config.block_configs[block_idx] - curr_activation_hooks_kwargs = { - **activation_hooks_kwargs, - "block_config": block_config, - } - - hook = hook_class(module, curr_activation_hooks_kwargs) - module.register_forward_hook(hook) - activation_hooks[module_name] = hook - return activation_hooks, hook_class + if hook_class not in pruning_mixin.supported_hooks(): + raise ValueError( + f"Hook class not supported for {pruning_mixin.__class__.__name__}, " + f"must be in {pruning_mixin.supported_hooks()}" + ) + + module_names_to_hook = pruning_mixin.get_module_names_to_hook(model) + activation_hooks = dict() + for block_idx, module_name in module_names_to_hook: + block_config = None + if block_idx is not None: + block_config = model.config.block_configs[block_idx] + curr_activation_hooks_kwargs = { + **activation_hooks_kwargs, + "block_config": block_config, + } + + module = model.get_submodule(module_name) + hook = hook_class(module, curr_activation_hooks_kwargs) + module.register_forward_hook(hook) + activation_hooks[module_name] = hook + + if len(activation_hooks) == 0: + raise ValueError("couldn't find any hooks") + + aprint(f"Found the following hooks: {activation_hooks.keys()}") + return activation_hooks diff --git a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py index ef5e5e9ad2..c043c20d5f 100644 --- a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py @@ -138,4 +138,4 @@ def launch_score_activations(cfg: DictConfig): mprint("Starting pruning activation scoring...") # The checkpoint manager inside validate_model handles all progress tracking - validate_model(args=cfg.pruning, pipeline_parallel=True) + validate_model(args=cfg.pruning) diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 1051fdbaf7..0d9ac068f2 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -15,6 +15,7 @@ """This module provides the main compression function for a model using MIP-based NAS search algorithm.""" +import hydra from omegaconf import DictConfig import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations @@ -51,24 +52,25 @@ def puzzletron( f"dataset_path={dataset_path}", ], ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) # Step 1: score_pruning_activations (distributed processing) score_pruning_activations.launch_score_activations(hydra_cfg) - # Step 2: pruning_ckpts (single process) - if dist.is_master(): - pruning_ckpts.launch_prune_ckpt(hydra_cfg) - dist.barrier() + # # Step 2: pruning_ckpts (single process) + # if dist.is_master(): + # pruning_ckpts.launch_prune_ckpt(hydra_cfg) + # dist.barrier() - # Step 4: build_library_and_stats (single process) - if dist.is_master(): - build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - dist.barrier() + # # Step 4: build_library_and_stats (single process) + # if dist.is_master(): + # build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + # dist.barrier() - # Step 5: calc_one_block_scores (distributed processing) - scoring.launch_scoring(hydra_cfg) + # # Step 5: calc_one_block_scores (distributed processing) + # scoring.launch_scoring(hydra_cfg) - # Step 6: mip_and_realize_models (distributed processing) - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + # # Step 6: mip_and_realize_models (distributed processing) + # mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) return hydra_cfg diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 25ee8c9eab..19c1bd6c83 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -291,6 +291,7 @@ def create(cls, data: Tensor) -> MyFakeTensor: def fake_tensor(*args, **kwargs) -> Tensor: dtype: Optional[torch.dtype] = kwargs.get("dtype") use_meta = kwargs.get("use_meta", False) + device = kwargs.get("device", "meta") if len(args) == 1 and isinstance(args[0], Tensor): if use_meta: @@ -298,7 +299,7 @@ def fake_tensor(*args, **kwargs) -> Tensor: else: fake_tensor = MyFakeTensor.create(args[0]) else: - fake_tensor = torch.empty(*args, dtype=dtype, device="meta") + fake_tensor = torch.empty(*args, dtype=dtype, device=device) if not use_meta: fake_tensor = MyFakeTensor.create(fake_tensor) diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py index dbb561b828..3397de6393 100644 --- a/modelopt/torch/puzzletron/tools/robust_json.py +++ b/modelopt/torch/puzzletron/tools/robust_json.py @@ -50,8 +50,13 @@ def default(self, o): # User-defined function in main — fallback to just the name return o.__name__ return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" if isinstance(o, datetime.timedelta): return str(o) + # Fallback for arbitrary objects: return their class path + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" return super().default(o) diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py index 1cb5e8489a..1cf02dc931 100644 --- a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -14,22 +14,30 @@ # limitations under the License. # mypy: ignore-errors -"""Provides utilities for distributed loading, saving, and manipulation of +""" +Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes. + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. """ import json from collections.abc import Iterable, Mapping from pathlib import Path -from typing import Literal, cast +from types import SimpleNamespace +from typing import Literal, Type, cast import numpy as np import torch import torch.distributed import torch.nn as nn +import transformers +from huggingface_hub import split_torch_state_dict_into_shards from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override @@ -43,23 +51,18 @@ ) from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.dummy_modules import ( + DummyBlock, + DummyLMHead, + DummyModule, + DummyWTE, +) from modelopt.torch.puzzletron.utils.utils import EmptyInitOnDevice -class DummyModule(nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) - - @staticmethod - def load_state_dict_post_hook( - module: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys - ) -> None: - incompatible_keys.missing_keys.clear() - incompatible_keys.unexpected_keys.clear() +class DeciLMDummyBlock(DummyModule): + """Dummy block for DeciLM models (used by replacement_library).""" - -class DummyBlock(DummyModule): def __init__(self, config: DeciLMConfig, block_index: int): super().__init__() self.config = config @@ -73,7 +76,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torc return x, None -class DummyWTE(DummyModule): +class DeciLMDummyWTE(DummyModule): + """Dummy word token embedding for DeciLM models (used by replacement_library).""" + def __init__(self, config: DeciLMConfig, dtype: torch.dtype | None = None): super().__init__() self.n_embd = config.get_hidden_size() @@ -86,7 +91,9 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: return result -class DummyLMHead(DummyModule): +class DeciLMDummyLMHead(DummyModule): + """Dummy LM head for DeciLM models (used by replacement_library).""" + def __init__(self, config: DeciLMConfig): super().__init__() self.vocab_size = config.vocab_size @@ -98,24 +105,44 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result -def create_local_shard_(model: DeciLMForCausalLM, owned_block_indexes: set[int]): - all_block_indexes = set(range(len(model.model.layers))) +def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) -> None: + """Set a submodule on a model by dotted path.""" + parts = module_name.split(".") + parent_path = ".".join(parts[:-1]) + attr = parts[-1] + parent_module = model.get_submodule(parent_path) if parent_path else model + setattr(parent_module, attr, new_submodule) + + +def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtime): + all_block_indexes = set(range(model.config.num_hidden_layers)) has_first_block = 0 in owned_block_indexes has_last_block = max(all_block_indexes) in owned_block_indexes unowned_block_indexes = all_block_indexes - owned_block_indexes for block_index in unowned_block_indexes: - model.model.layers[block_index] = cast( - "DeciLMDecoderLayer", DummyBlock(model.config, block_index) + decoder_layer_name = descriptor.layer_block_name(block_index) + decoder_layer = model.get_submodule(decoder_layer_name) + set_submodule( + model, + decoder_layer_name, + descriptor.create_dummy_block(decoder_layer, block_index=block_index), ) - if not has_first_block: - model.set_input_embeddings(DummyWTE(model.config)) + # If we have the last block with tied embeddings, keep embed_tokens so lm_head works. + # load_sharded_state_dict will load embed_tokens.weight from the first shard's checkpoint file, + # and since they're tied, lm_head.weight gets populated too. + if not has_first_block and not (has_last_block and model.config.tie_word_embeddings): + set_submodule( + model, + descriptor.input_embedding_name(), + DummyWTE(model.config.hidden_size, dtype=runtime.dtype), + ) if not has_last_block: - model.model.set_final_layer_norm(nn.Identity()) + set_submodule(model, descriptor.final_norm_name(), nn.Identity()) if not (model.config.tie_word_embeddings and has_first_block): - model.set_output_embeddings(DummyLMHead(model.config)) + set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(model.config)) return model @@ -130,42 +157,74 @@ def create_dummy_model( rope_cls = rope_type_to_class[model_config.position_embedding_type] model.model.rotary_emb = rope_cls(config=model.config) - model.model.set_input_embeddings(DummyWTE(model.config, dtype)) + model.model.set_input_embeddings(DeciLMDummyWTE(model.config, dtype)) model.model.set_final_layer_norm(nn.Identity()) - model.set_output_embeddings(DummyLMHead(model.config)) + model.set_output_embeddings(DeciLMDummyLMHead(model.config)) for block_index in range(model_config.get_num_hidden_layers()): - model.model.layers[block_index] = DummyBlock(model.config, block_index) + model.model.layers[block_index] = DeciLMDummyBlock(model.config, block_index) return model +def _get_model_class_from_config(config: PretrainedConfig): + """ + Get the model class from config.architectures field. + Works for any model registered in transformers (CausalLM, VL models, etc.). + Falls back to AutoModelForCausalLM if architectures is not available. + """ + if hasattr(config, "architectures") and config.architectures: + model_class_name = config.architectures[0] + if hasattr(transformers, model_class_name): + return getattr(transformers, model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, falling back to AutoModelForCausalLM" + ) + return AutoModelForCausalLM + + def load_and_shard_model( + descriptor, checkpoint_path: str | Path, owned_block_indexes: set[int] | Literal["auto"] = "auto", - model_config: DeciLMConfig | None = None, - model_config_overrides: Mapping | None = None, - model_dtype: torch.dtype = torch.bfloat16, -) -> DeciLMForCausalLM: + model_config: PretrainedConfig | None = None, +): checkpoint_path = Path(checkpoint_path) - with torch.device(dist.local_rank()): + runtime = SimpleNamespace( + device=torch.device(dist.local_rank()), + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + use_autocast=True, # Default: use autocast; descriptor can override + ) + + with runtime.device: if model_config is None: - model_config = load_model_config( - checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True - ) + model_config = load_model_config(checkpoint_path) if owned_block_indexes == "auto": owned_block_indexes = set( - np.array_split(np.arange(model_config.get_num_hidden_layers()), dist.size())[ - dist.rank() + np.array_split(np.arange(model_config.num_hidden_layers), runtime.world_size)[ + runtime.global_rank ] ) mprint("Initializing model shards") - model_shard = create_sharded_model( - model_config=model_config, - owned_block_indexes=owned_block_indexes, - ) + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + with deci_x_patcher( + model_descriptor=descriptor, block_configs=getattr(model_config, "block_configs", None) + ): + model_shard = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=model_config, + owned_block_indexes=owned_block_indexes, + ) if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( checkpoint_path / SAFE_WEIGHTS_INDEX_NAME @@ -178,27 +237,47 @@ def load_and_shard_model( shard_state_dict = load_sharded_state_dict( model_name_or_path=str(checkpoint_path), keys_to_load=shard_keys, - device=torch.device(dist.local_rank()), + device=runtime.device, ) new_names = set(shard_state_dict.keys()) mprint(f"{new_names=}") - model_shard.load_state_dict(shard_state_dict, assign=True) + # strict=False: allows missing lm_head.weight when tie_word_embeddings=True (e.g., Llama 3.2 3B) + model_shard.load_state_dict(shard_state_dict, strict=False, assign=True) del shard_state_dict - if model_config.tie_word_embeddings and (0 in owned_block_indexes): - # re-tie the weights in case the connection was severed + # Re-tie weights after load_state_dict with assign=True, which severs the tie. + # Needed on first rank (owns embed_tokens) and last rank (owns lm_head). + has_first_block = 0 in owned_block_indexes + has_last_block = (model_config.num_hidden_layers - 1) in owned_block_indexes + if model_config.tie_word_embeddings and (has_first_block or has_last_block): model_shard.tie_weights() + + # On the last rank with tied embeddings, we kept embed_tokens in create_local_shard_() + # just to load the weight and tie it to lm_head. Now replace it with a dummy so it + # doesn't interfere with the pipeline forward pass (only rank 0 should run embed_tokens). + if model_config.tie_word_embeddings and has_last_block and not has_first_block: + set_submodule( + model_shard, + descriptor.input_embedding_name(), + DummyWTE(model_config.hidden_size, dtype=runtime.dtype), + ) else: mprint("Loading state_dict in main process") - state_dict = load_state_dict(checkpoint_path) if dist.is_master() else None + state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None mprint("Distributing model to shards") load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict) del state_dict - model_shard.type(model_dtype) + descriptor.init_rotary_embedding(model_shard, runtime) + + model_shard.type(runtime.dtype) + + # Configure autocast based on model descriptor (some models like Qwen3-VL MoE + # have dtype bugs under autocast) + runtime.use_autocast = descriptor.uses_autocast() params_on_meta_device = [ param_name @@ -206,14 +285,16 @@ def load_and_shard_model( if param.device == torch.device("meta") ] assert len(params_on_meta_device) == 0, ( - f"[global_rank={dist.rank()}] Couldn't load params {params_on_meta_device}" + f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" ) return model_shard def create_sharded_model( - model_config: DeciLMConfig, + runtime, + descriptor, + model_config: PretrainedConfig, owned_block_indexes: set[int], device: str | torch.device | None = "meta", dtype: torch.dtype | None = torch.float32, @@ -224,14 +305,24 @@ def create_sharded_model( dist.barrier() with EmptyInitOnDevice(device="meta", dtype=dtype): - model = DeciLMForCausalLM(model_config) - create_local_shard_(model=model, owned_block_indexes=owned_block_indexes) + # Get model class from config.architectures (works for CausalLM, VL models, etc.) + model_class = _get_model_class_from_config(model_config) + # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() + if model_class is AutoModelForCausalLM: + model = model_class.from_config(model_config, trust_remote_code=True) + else: + model = model_class._from_config(model_config) + create_local_shard_( + model=model, + owned_block_indexes=owned_block_indexes, + descriptor=descriptor, + runtime=runtime, + ) if device != torch.device("meta"): local_shard_state_dict = { k: torch.empty_like(v, device=device) for k, v in model.state_dict().items() } - model.load_state_dict(local_shard_state_dict, assign=True) return model @@ -288,7 +379,9 @@ def load_state_dict_to_shards( def save_sharded_model( model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path ): - """out_path is usually output_checkpoint_path / "model.safetensors" """ + """ + out_path is usually output_checkpoint_path / "model.safetensors" + """ dist.barrier() if isinstance(model_shard, torch.nn.Module): @@ -346,7 +439,9 @@ def load_sharded_state_dict( keys_to_load: Iterable[str] | None = None, device: torch.device | str = "cpu", ) -> dict[str, torch.Tensor]: - """keys_to_load: entire state_dict if None, else partial state_dict containing only these keys""" + """ + keys_to_load: entire state_dict if None, else partial state_dict containing only these keys + """ shard_paths = _resolve_shard_paths(model_name_or_path) # print(f"shard_paths: {shard_paths}") partial_state_dict = {} diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py index 6c3dc3640c..2a5bf22432 100644 --- a/modelopt/torch/puzzletron/tools/validate_model.py +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -12,42 +12,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Provides a function to validate a model. Runs a model forward pass on a dataset and calculates +# mypy: ignore-errors +""" +Provides a function to validate a model. Runs a model forward pass on a dataset and calculates the loss, and optionally registers hooks to capture the inputs and the outputs of pytorch modules that are used for activation scoring for pruning. TODO: Consider moving this a separate module dedicated for scoring + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. """ import textwrap from pathlib import Path +from typing import Type import torch from omegaconf import DictConfig from torch import nn from torch.utils.data import DataLoader -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizerBase, -) +from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.activation_scoring.activation_hooks.utils import ( register_activation_hooks, ) -from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_checkpoint +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import Same from modelopt.torch.puzzletron.tools.logger import aprint, mprint -from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( + load_and_shard_model, + set_submodule, +) from modelopt.torch.puzzletron.utils.data.dataloaders import create_validation_dataloader -from modelopt.torch.puzzletron.utils.parsing import simple_parse_args_string +from modelopt.torch.puzzletron.utils.parsing import ( + simple_parse_args_string, # noqa: F401 (kept for backwards compat) +) from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( HiddenStatesAndLMHead, calculate_losses_pipeline, ) -from modelopt.torch.puzzletron.utils.validation import calculate_losses """ Two goals: @@ -70,7 +77,6 @@ def validate_model( tokenizer: PreTrainedTokenizerBase | None = None, target_hidden_states_per_batch: list[torch.Tensor] | None = None, return_hidden_states: bool = False, - pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader: DataLoader | None = None, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: @@ -79,53 +85,43 @@ def validate_model( Args: args: Configuration object containing the following attributes: - Model Configuration attributes: - - - ``model_name_or_path`` (str): Path to model checkpoint or HuggingFace model name. - Required unless model is passed directly. - - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration attributes: - - - ``dataset_path`` (str): Path to the validation dataset. - - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. - - ``data_column`` (str): Column name in dataset containing text data. - - ``block_size`` (int): Maximum sequence length for tokenization. - - ``eval_samples`` (int, optional): Number of samples to evaluate. Uses all if None. - - ``val_dataset_name`` (str): Name of validation dataset split. - - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. - - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. - - Data Processing attributes: - - - ``micro_batch_size`` (int): Batch size for evaluation. - - ``seed`` (int): Random seed for reproducibility. - - ``shuffle_seed`` (int, optional): Seed for shuffling data. Uses seed if None. - - ``varlen`` (bool): Enable variable-length sequences. - - ``bos_rate`` (float): Rate of adding BOS token. - - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. - - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. - - Activation Hooks attributes: - - - ``activations_log_dir`` (str, optional): Directory to log activation scores. - If provided, hooks will be registered to capture activations. - - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. - If string, comma-separated format: "arg1=val1,arg2=val2". - - Execution Options attributes: - - - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. - - ``write_results`` (bool): Write validation results to file. + **Model Configuration:** + - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. Required unless model is passed directly. + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + **Dataset Configuration:** + - dataset_path (str): Path to the validation dataset. + - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. Uses all if None. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + **Data Processing:** + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + **Activation Hooks:** + - activations_log_dir (str, optional): Directory to log activation scores. If provided, hooks will be registered to capture activations. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. If string, comma-separated format: "arg1=val1,arg2=val2". + + **Execution Options:** + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. + - write_results (bool): Write validation results to file. model: Pre-loaded model. If None, will be loaded from args.model_name_or_path. tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args. target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation. return_hidden_states: Whether to return hidden states from the model. - pipeline_parallel: Enable pipeline parallelism for large models. - calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. - False calculates only a small suite for efficiency. + calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. False calculates only a small suite for efficiency. val_dataloader: Pre-created validation dataloader. If None, will be created from args. Returns: @@ -136,29 +132,31 @@ def validate_model( Returns (None, None) if not on master rank. """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + if val_dataloader is None: val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None validation_full_iters = ( args.eval_samples // args.micro_batch_size ) # model pipeline, single data rank - model = prepare_model(args, model, pipeline_parallel) + model = prepare_model(args, descriptor=descriptor, model=model) just_model_forward = False checkpoint_manager = None activation_hooks = None if args.activations_log_dir is not None: - activation_hooks_kwargs = ( - simple_parse_args_string(args.activation_hooks_kwargs) - if isinstance(args.activation_hooks_kwargs, str) - else args.activation_hooks_kwargs - ) + activation_hooks_kwargs = args.activation_hooks_kwargs or {} activation_hooks_kwargs["validation_full_iters"] = validation_full_iters + hook_class = args.hook_class - # Create activation hooks first - activation_hooks, hook_class = register_activation_hooks( - model=model, activation_hooks_kwargs=activation_hooks_kwargs + # Create activation hooks using pruning mixin + activation_hooks = register_activation_hooks( + model=model, + activation_hooks_kwargs=activation_hooks_kwargs, + hook_class=hook_class, + pruning_mixin=args.pruning_mixin, ) # Create checkpoint manager with hooks @@ -181,26 +179,23 @@ def validate_model( else: mprint("No checkpoint found, starting fresh") just_model_forward = True - model.lm_head = nn.Identity() - - if not pipeline_parallel: - losses, hidden_states_per_batch = calculate_losses( - model=model, - dataloader=val_dataloader, - checkpoint_manager=checkpoint_manager, - ) - else: - losses, hidden_states_per_batch = calculate_losses_pipeline( - stitched_model=model, - dataloader=val_dataloader, - target_hidden_states_per_batch=target_hidden_states_per_batch, - return_hidden_states=return_hidden_states, - calculate_full_score_ablations=calculate_full_score_ablations, - calc_on_cpu=args.calc_losses_on_cpu, - just_model_forward=just_model_forward, - checkpoint_manager=checkpoint_manager, - autocast_dtype=getattr(torch, args.autocast_dtype.strip("torch.")), - ) + set_submodule(model, descriptor.output_embedding_name(), Same()) + + losses, hidden_states_per_batch = calculate_losses_pipeline( + stitched_model=model, + dataloader=val_dataloader, + target_hidden_states_per_batch=target_hidden_states_per_batch, + return_hidden_states=return_hidden_states, + calculate_full_score_ablations=calculate_full_score_ablations, + calc_on_cpu=args.calc_losses_on_cpu, + just_model_forward=just_model_forward, + checkpoint_manager=checkpoint_manager, + autocast_dtype=getattr( + torch, getattr(args, "autocast_dtype", "torch.bfloat16").strip("torch.") + ), + descriptor=descriptor, + use_autocast=descriptor.uses_autocast(), + ) if losses is not None: avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()} @@ -224,31 +219,13 @@ def validate_model( def prepare_model( - args: DictConfig, model: PreTrainedModel | None = None, pipeline_parallel: bool = False + args: DictConfig, + descriptor: Type[ModelDescriptor], + model: PreTrainedModel | None = None, ) -> nn.Module: if model is None: assert args.model_name_or_path is not None - if pipeline_parallel: - model = load_and_shard_model( - args.model_name_or_path, - model_config_overrides={"block_size": args.block_size}, - model_dtype=getattr(torch, args.model_dtype.strip("torch.")), - ) - else: - try: - model = load_checkpoint( - args.model_name_or_path, - model_config_overrides={"block_size": args.block_size}, - ignore_unexpected_config_keys=True, - ) - model.to("cuda") - except FileNotFoundError: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - torch_dtype="auto", - device_map="auto", - trust_remote_code=True, - ) + model = load_and_shard_model(descriptor=descriptor, checkpoint_path=args.model_name_or_path) model.eval() return model diff --git a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py index db1e8f2cea..90fea13c56 100644 --- a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. +""" +Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. Coordinates forward passes and loss computation through model shards distributed across GPUs using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. @@ -22,16 +23,18 @@ """ # mypy: ignore-errors +import traceback +from contextlib import nullcontext +from typing import Type + import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( - DeciLMForCausalLM, - LMHead, -) +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import LMHead from modelopt.torch.puzzletron.sewing_kit import ( ExternalTarget, InputArgs, @@ -51,6 +54,23 @@ from modelopt.torch.puzzletron.utils.validation import _organize_outputs, calculate_batch_outputs +def _log_forward_error(e: Exception, rank: int, batch_idx: int, num_batches: int) -> None: + """Log detailed error info for distributed forward pass failures. + + When one rank crashes during distributed forward, others may hang waiting for communication. + This logging helps diagnose which rank failed and why. + """ + error_msg = ( + f"\n{'=' * 60}\n" + f"[Rank {rank}] ERROR in stitched_model forward (batch {batch_idx}/{num_batches})\n" + f"Error: {type(e).__name__}: {e}\n" + f"{'=' * 60}\n" + f"{traceback.format_exc()}" + f"{'=' * 60}\n" + ) + print(error_msg, flush=True) + + class HiddenStatesAndLMHead(list): def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): super().__init__(hidden_states) @@ -59,7 +79,7 @@ def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Ten @torch.no_grad() def calculate_losses_pipeline( - stitched_model: StitchedModule | DeciLMForCausalLM, + stitched_model: StitchedModule, dataloader: DataLoader | None, target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, return_hidden_states: bool = False, @@ -68,8 +88,11 @@ def calculate_losses_pipeline( just_model_forward: bool = False, checkpoint_manager=None, autocast_dtype: torch.dtype = torch.bfloat16, + descriptor: Type[ModelDescriptor] = None, + use_autocast: bool = True, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: - """Do model forward on each batch and calculate LM loss. + """ + Do model forward on each batch and calculate LM loss. Optionally also calculate kl_div loss and other metrics from given target_hidden_states_per_batch. Optionally return hidden states per batch. Does not support data-parallel. @@ -87,8 +110,8 @@ def calculate_losses_pipeline( target_hidden_states_per_batch: list[torch.Tensor], returned if return_hidden_states=True """ - if isinstance(stitched_model, DeciLMForCausalLM): - stitched_model = perform_pipeline_stitches(stitched_model) + if not isinstance(stitched_model, StitchedModule): + stitched_model = perform_pipeline_stitches(stitched_model, descriptor) params = list(stitched_model.parameters()) model_device = params[0].device if params else "cpu" @@ -145,14 +168,24 @@ def calculate_losses_pipeline( stitched_model.eval() - with torch.autocast(device_type="cuda", dtype=autocast_dtype): + # Use autocast for mixed precision, or nullcontext if disabled + # (some models like Qwen3-VL MoE have dtype bugs under autocast) + autocast_ctx = ( + torch.autocast(device_type="cuda", dtype=autocast_dtype) if use_autocast else nullcontext() + ) + with autocast_ctx: + fake_input_ids = fake_tensor(1, seq_len, dtype=torch.long, device=model_device) for i_batch in progress_bar: if dist.is_master(): input_ids = all_input_ids[i_batch].to(model_device) else: - input_ids = fake_tensor(1, seq_len, dtype=torch.long) + input_ids = fake_input_ids - output = stitched_model({}, {}, input_ids) + try: + output = stitched_model({}, {}, input_ids) + except Exception as e: + _log_forward_error(e, dist.rank(), i_batch, num_batches) + raise if dist.is_last_process(): logits = output.captured_outputs.get("model_output") @@ -183,6 +216,16 @@ def calculate_losses_pipeline( outputs.append(batch_outputs) + # Free GPU memory after processing each batch + del logits, hidden_states, targets + if target_hidden_states is not None: + del target_hidden_states + if target_logits is not None: + del target_logits + + # Free output tensor memory on all ranks + del output + # Update checkpoint progress periodically if checkpoint_manager: checkpoint_manager.update_progress(i_batch + 1, num_batches) @@ -200,13 +243,28 @@ def calculate_losses_pipeline( return losses, hidden_states_per_batch -def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: +def perform_pipeline_stitches( + model, + descriptor: Type[ModelDescriptor], +) -> StitchedModule: + """Create pipeline stitches for distributed model evaluation. + + Args: + model: The model to stitch (any HuggingFace model with AnyModel descriptor). + descriptor: ModelDescriptor for layer naming. + """ target = ModuleTarget("module", model) stitcher = Needle() + num_layers = model.config.num_hidden_layers + is_real_block = np.flatnonzero( - [not isinstance(block, DummyBlock) for block in model.model.layers] + [ + not isinstance(model.get_submodule(descriptor.layer_block_name(i)), DummyBlock) + for i in range(num_layers) + ] ) + first_block, last_block = is_real_block.min(), is_real_block.max() if dist.rank() != 0: @@ -216,7 +274,7 @@ def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: name="activations", adapter=lambda x: InputArgs(x) ), target.input( - name=f"model.layers.{first_block}", + name=descriptor.layer_block_name(first_block), reducer=InputReducer( lambda acc, override, orig, *args: override + orig.drop_args(0) ), @@ -226,17 +284,17 @@ def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: if not dist.is_last_process(): # send activations to next rank stitcher.stitch( - target.output(f"model.layers.{last_block}"), + target.output(descriptor.layer_block_name(last_block)), RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"), ) else: # register model output stitcher.stitch( - target.output(name="lm_head"), + target.output(name=descriptor.output_embedding_name()), ExternalTarget().output("model_output"), ) stitcher.stitch( - target.output(name="model.norm"), + target.output(name=descriptor.final_norm_name()), ExternalTarget().output("hidden_states"), ) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index 3a5d9a8cee..585567715b 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -24,6 +24,7 @@ from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron import puzzletron from modelopt.torch.puzzletron.anymodel import convert_model # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) @@ -106,7 +107,7 @@ def _test_puzzletron_multiprocess_job( puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, hf_config_name, hybrid_override_pattern ) - hydra_config_dir = ( # noqa: F841 + hydra_config_dir = ( project_root_path / f"tests/gpu/torch/puzzletron/resources/configs/{hydra_config_subdir}" ) @@ -120,10 +121,10 @@ def _test_puzzletron_multiprocess_job( dist.barrier() # TODO commented for the duration of merging process from dkorzekwa/any_model to feature/puzzletron - # # Compress the model using a one-click approach - # puzzletron.puzzletron( - # str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) - # ) + # Compress the model using a one-click approach + puzzletron.puzzletron( + str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) + ) # # # # Check assertions From d358eb3bf5d614287f785aad9c0a231e2d15c967 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 12 Mar 2026 15:54:52 +0100 Subject: [PATCH 36/62] Draft: Merge anymodel pruning (#990) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Merging dkorzekwa/anymodel_pruning into dkorzekwa/anymodel_acivation scoring - this MR is only for reviewing. Ultimately dkorzekwa/anymodel_pruning should be merged into feature/puzzletron once dkorzekwa/anymodel_acivation scoring is merged there. ## Summary by CodeRabbit * **Refactor** * Moved model initialization to a descriptor-driven flow; accepts dict or JSON overrides and new optional init parameters. Safer handling for vision–language models, improved state creation and checkpoint saving, and clearer profiling output. * **Bug Fixes** * Re-enabled pruning checkpoint step to run on a single process with proper synchronization. * **Documentation** * Clarified and reformatted validation docstrings, expanding configuration and data-processing options. --------- Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/puzzletron.py | 8 +- .../init_child_from_parent.py | 127 +++++++----------- .../torch/puzzletron/tools/validate_model.py | 10 +- 3 files changed, 59 insertions(+), 86 deletions(-) diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 0d9ac068f2..94a1de57ea 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -57,10 +57,10 @@ def puzzletron( # Step 1: score_pruning_activations (distributed processing) score_pruning_activations.launch_score_activations(hydra_cfg) - # # Step 2: pruning_ckpts (single process) - # if dist.is_master(): - # pruning_ckpts.launch_prune_ckpt(hydra_cfg) - # dist.barrier() + # Step 2: pruning_ckpts (single process) + if dist.is_master(): + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() # # Step 4: build_library_and_stats (single process) # if dist.is_master(): diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index 46e403c5f4..36e41c4b6a 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -14,15 +14,22 @@ # limitations under the License. # mypy: ignore-errors -"""TODO Add description""" +"""Initialize child models from parent models using AnyModel approach with deci_x_patcher.""" import json import time +from pathlib import Path +from typing import Optional import torch import yaml +from transformers import AutoModelForCausalLM -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( GQAInitMode, HiddenSizeInitMode, @@ -31,85 +38,37 @@ create_child_state_dict, update_model_config, ) -from modelopt.torch.puzzletron.tools.checkpoint_utils import ( - copy_tokenizer, - load_model_config, - load_state_dict, -) +from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( _save_checkpoint, copy_deci_lm_hf_code, + load_model_config, ) from modelopt.torch.puzzletron.tools.logger import mprint - -""" - -Usage example - remove all/some routed experts: -=============================================== - -PARENT_DIR=".../meta-llama/Llama-4-Scout-17B-16E-Instruct--deci-hf" - -MLP_INIT_MODE="ConcatExpertsIntoDenseFFN" - -## remove all routed experts, turn the shared expert into a dense FFN -# OUTPUT_DIR="/.../micro_scout/Scout-remove-routed-experts" -# MODEL_CONFIG_OVERRIDES_JSON=' -# { -# "ffn": [ -# { -# "moe": null, -# "intermediate_size": 14336, -# "gated": true, -# "hidden_act": "silu" -# } -# ] -# } -# ' - -## concat the shared expert with one routed expert into a dense FFN -OUTPUT_DIR=".../scratch/micro_scout/Scout-ConcatExpertsIntoDenseFFN-concat-shared-and-3-routed" -MODEL_CONFIG_OVERRIDES_JSON=' -{ - "ffn": [ - { - "moe": null, - "intermediate_size": 14336, - "gated": true, - "hidden_act": "silu" - } - ] -} -' - -echo "" -echo "MODEL_CONFIG_OVERRIDES_JSON:" -echo "${MODEL_CONFIG_OVERRIDES_JSON}" - -python -m modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent \ - --parent_checkpoint_dir="$PARENT_DIR" \ - --model_config_overrides_json="$MODEL_CONFIG_OVERRIDES_JSON" \ - --output_checkpoint_dir="$OUTPUT_DIR" \ - --mlp_init_mode="$MLP_INIT_MODE" \ - --mlp_init_config_yaml="$MLP_INIT_CONFIG_YAML" -""" +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config def init_child_from_parent( + descriptor: ModelDescriptor, + pruning_mixin, parent_checkpoint_dir: str, - model_config_overrides_json: str, + model_config_overrides_dict: dict | str, output_checkpoint_dir: str, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, - mlp_init_config_yaml: str | None, + mlp_init_config_yaml: Optional[str], linear_init_mode: LinearInitMode, - hidden_size_init_mode: HiddenSizeInitMode | None = None, - channel_importance_path: str | None = None, - max_workers: int | None = None, # Auto-calculate optimal workers if None - max_layer_workers: int | None = None, # Auto-calculate optimal workers if None + hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, + channel_importance_path: Optional[str] = None, + max_workers: Optional[int] = None, # Auto-calculate optimal workers if None + max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None ) -> None: - """Init child models from parent models in the style of bypass training, + """ + Init child models from parent models in the style of bypass training, but without having to run the entire bypass pipeline. + Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations. + I/O Optimization Parameters: - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files)) - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) @@ -123,16 +82,16 @@ def init_child_from_parent( "We do not support random init of any subblock in this script to avoid initializing the student model" ) + descriptor = ModelDescriptorFactory.get(descriptor) + copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir) parent_model_config = load_model_config(parent_checkpoint_dir) parent_state_dict = load_state_dict(parent_checkpoint_dir) - # Parse the model config overrides - if isinstance(model_config_overrides_json, str): - model_config_overrides_dict = json.loads(model_config_overrides_json) - else: - model_config_overrides_dict = model_config_overrides_json + # Parse JSON if string + if isinstance(model_config_overrides_dict, str): + model_config_overrides_dict = json.loads(model_config_overrides_dict) # Separate global config overrides from block-level overrides global_config_overrides = {} @@ -146,7 +105,7 @@ def init_child_from_parent( # Load child model config with global overrides child_model_config = load_model_config( - checkpoint_dir=parent_checkpoint_dir, + parent_checkpoint_dir, model_config_overrides=global_config_overrides, ignore_unexpected_config_keys=True, ) @@ -159,12 +118,23 @@ def init_child_from_parent( ) with torch.device("meta"): - child_model = DeciLMForCausalLM(child_model_config) + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + with deci_x_patcher( + model_descriptor=descriptor, block_configs=child_model_config.block_configs + ): + model_class = _get_model_class_from_config(child_model_config) + # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() + if model_class is AutoModelForCausalLM: + child_model = model_class.from_config(child_model_config, trust_remote_code=True) + else: + child_model = model_class._from_config(child_model_config) + child_state_dict_with_meta_tensors = child_model.state_dict() mlp_init_config = ( yaml.safe_load(mlp_init_config_yaml) - if isinstance(mlp_init_config_yaml, str) is None + if isinstance(mlp_init_config_yaml, str) else mlp_init_config_yaml ) @@ -172,6 +142,8 @@ def init_child_from_parent( mprint("Starting create_child_state_dict...") start_time = time.time() child_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, original_state_dict=parent_state_dict, new_state_dict=child_state_dict_with_meta_tensors, original_config=parent_model_config, @@ -182,7 +154,7 @@ def init_child_from_parent( linear_init_mode=linear_init_mode, hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs, channel_importance_path=channel_importance_path, - max_layer_workers=max_layer_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, ) create_child_state_dict_time = time.time() - start_time mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds") @@ -196,7 +168,8 @@ def init_child_from_parent( child_model_config, child_state_dict, output_checkpoint_dir, - max_workers=max_workers, # Will auto-calculate if None + descriptor, + max_workers=max_workers, ) save_checkpoint_time = time.time() - start_time mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") @@ -207,7 +180,7 @@ def init_child_from_parent( total_core_time = create_child_state_dict_time + save_checkpoint_time actual_layer_workers = max_layer_workers if max_layer_workers else "auto" actual_io_workers = max_workers if max_workers else "auto" - mprint("\n=== PROFILING SUMMARY ===") + mprint(f"\n=== PROFILING SUMMARY ===") mprint( f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)" ) @@ -216,4 +189,4 @@ def init_child_from_parent( ) mprint(f"Total core processing: {total_core_time:.2f}s") mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") - mprint("=========================\n") + mprint(f"=========================\n") diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py index 2a5bf22432..e68bb0f439 100644 --- a/modelopt/torch/puzzletron/tools/validate_model.py +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -85,12 +85,12 @@ def validate_model( Args: args: Configuration object containing the following attributes: - **Model Configuration:** + Model Configuration: - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. Required unless model is passed directly. - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. - **Dataset Configuration:** + Dataset Configuration: - dataset_path (str): Path to the validation dataset. - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. - data_column (str): Column name in dataset containing text data. @@ -100,7 +100,7 @@ def validate_model( - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. - load_dataset_fn (callable, optional): Custom function to load the dataset. - **Data Processing:** + Data Processing: - micro_batch_size (int): Batch size for evaluation. - seed (int): Random seed for reproducibility. - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. @@ -109,11 +109,11 @@ def validate_model( - fim_rate (float): Fill-in-the-middle rate for code completion tasks. - fim_spm_rate (float): SPM-based fill-in-the-middle rate. - **Activation Hooks:** + Activation Hooks: - activations_log_dir (str, optional): Directory to log activation scores. If provided, hooks will be registered to capture activations. - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. If string, comma-separated format: "arg1=val1,arg2=val2". - **Execution Options:** + Execution Options: - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. - write_results (bool): Write validation results to file. From 8e827f34b682a29361978438daea17aeda38aef2 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 12 Mar 2026 17:51:27 +0100 Subject: [PATCH 37/62] Draft: Merging anymodel:build_library_and_stats (#993) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Merging dkorzekwa/anymodel_build_library_and_stats into dkorzekwa/anymodel_pruning - this MR is only for reviewing. Ultimately dkorzekwa/anymodel_build_library_and_stats should be merged into feature/puzzletron once dkorzekwa/anymodel_pruning scoring is merged there. ## Summary by CodeRabbit * **New Features** * Integrated language-model descriptors to improve compatibility and produce model-specific stats. * **Bug Fixes** * Re‑enabled the single‑process library build and stats workflow step. * **Documentation** * Expanded usage docs and CLI guidance for the replacement library builder. * **Refactor** * Improved typing/annotations and normalized key‑value head handling, affecting attention memory/parameter calculations and textual outputs. --------- Signed-off-by: Daniel Korzekwa --- .../puzzletron/build_library_and_stats.py | 9 +++- modelopt/torch/puzzletron/puzzletron.py | 8 ++-- .../build_replacement_library.py | 29 +++++++++--- .../calc_subblock_params_and_memory.py | 4 +- .../subblock_stats/calc_subblock_stats.py | 45 ++++++++++++++----- .../torch/puzzletron/tools/validate_model.py | 1 - modelopt/torch/puzzletron/utils/utils.py | 33 ++++++-------- 7 files changed, 83 insertions(+), 46 deletions(-) diff --git a/modelopt/torch/puzzletron/build_library_and_stats.py b/modelopt/torch/puzzletron/build_library_and_stats.py index 5f04f60494..31cebdf6be 100644 --- a/modelopt/torch/puzzletron/build_library_and_stats.py +++ b/modelopt/torch/puzzletron/build_library_and_stats.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unified command that runs build_replacement_library followed by calc_subblock_stats. +""" +Unified command that runs build_replacement_library followed by calc_subblock_stats. This script combines the functionality of both commands into a single workflow: 1. First, it builds the replacement library for the puzzle @@ -28,17 +29,21 @@ all the same configuration parameters for both build_replacement_library and calc_subblock_stats. """ +import hydra from omegaconf import DictConfig from modelopt.torch.puzzletron.replacement_library.build_replacement_library import ( launch_build_replacement_library, ) from modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats import launch_calc_subblock_stats +from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.parsing import format_global_config def launch_build_library_and_stats(cfg: DictConfig) -> None: - """Launch both build_replacement_library and calc_subblock_stats in sequence. + """ + Launch both build_replacement_library and calc_subblock_stats in sequence. Args: cfg: Hydra configuration containing settings for both commands diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 94a1de57ea..87d90fdd91 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -62,10 +62,10 @@ def puzzletron( pruning_ckpts.launch_prune_ckpt(hydra_cfg) dist.barrier() - # # Step 4: build_library_and_stats (single process) - # if dist.is_master(): - # build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - # dist.barrier() + # Step 4: build_library_and_stats (single process) + if dist.is_master(): + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + dist.barrier() # # Step 5: calc_one_block_scores (distributed processing) # scoring.launch_scoring(hydra_cfg) diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index 1618aceaf3..0f5ecd2158 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -12,17 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This module constructs the replacement library JSON files from a puzzle directory containing +""" +This module constructs the replacement library JSON files from a puzzle directory containing multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock configurations, builds a library of available replacements, and generates solutions for layer replacement in compressed models. The resulting replacement library can then be used by ReplacementLibrary to efficiently load models with mixed teacher/student layers. + +Standard Puzzle Usage: +====================== +python -m modelopt.torch.puzzletron.replacement_library.build_replacement_library PUZZLE_DIR + +Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended) +though you can supply an explicit --teacher_checkpoint_dir. + +--add_ffn_no_ops and --add_attention_no_ops are optional (default True), + + """ # mypy: ignore-errors import json from pathlib import Path -from typing import Any +from typing import Any, Type import pandas as pd from omegaconf import DictConfig @@ -57,7 +69,8 @@ def build_replacement_library( add_ffn_no_ops: bool = True, add_attention_no_ops: bool = True, ) -> None: - """For normal puzzle runs, use default values. + """ + For normal puzzle runs, use default values. For advanced use cases, see the Usage section. """ master_puzzle_dir = Path(master_puzzle_dir) @@ -90,7 +103,9 @@ def build_replacement_library( def launch_build_replacement_library(cfg: DictConfig) -> None: - """Launch the build replacement library function with Hydra configuration.""" + """ + Launch the build replacement library function with Hydra configuration. + """ mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}") mprint(f"Teacher directory: {cfg.teacher_dir}") mprint( @@ -113,8 +128,8 @@ def infer_teacher_dir( teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher" if not teacher_checkpoint_dir.exists(): raise ValueError( - "You must either provide the --teacher_checkpoint_dir argument, or create a link to the " - "teacher dir under '{PUZZLE_DIR}/ckpts'." + f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the " + f"teacher dir under '{{PUZZLE_DIR}}/ckpts'." ) teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute() return teacher_checkpoint_dir @@ -362,7 +377,7 @@ def _add_no_op_subblock_rows( def _get_rows_with_no_op_subblock( subblocks_df: pd.DataFrame, no_op_subblock: str -) -> tuple[pd.DataFrame, type[AttentionConfig] | type[FFNConfig]]: +) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]: other_subblock = "ffn" if no_op_subblock == "attention" else "attention" subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig no_op_subblock_config = subblock_cls(no_op=True) diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py index 2e8630bc98..88081d1773 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py @@ -189,7 +189,7 @@ def calculate_attention_memory( ): seq_len = min(seq_len, attention_chunk_size) - kv_dim = calculate_kv_dim(attention_config.n_heads_in_group, n_head, n_embd) + kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd) total_num_tokens = seq_len * (batch_size + prefill_queue_size) kv_cache_size = total_num_tokens * kv_dim query_prefill_size = seq_len * n_embd if allocate_prefill_query else 0 @@ -208,7 +208,7 @@ def calculate_attention_params( n_embd: int, n_head: int, ) -> int: - kv_dim = calculate_kv_dim(attention_config.n_heads_in_group, n_head, n_embd) + kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd) return ( n_embd * n_embd * 2 # Wq + Wo + n_embd * kv_dim # Wk + Wv diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index 07597eb5c0..2db0bc3916 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -19,11 +19,10 @@ import dataclasses import json import os -from collections.abc import Iterable from functools import partial from itertools import product from pathlib import Path -from typing import TypeVar +from typing import Iterable, Optional, Type, TypeVar os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -33,6 +32,10 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from tqdm import tqdm +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, @@ -56,6 +59,15 @@ # Type variable for dataclasses T_DataClass = TypeVar("T_DataClass") +""" +Usage: +python -m modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ] + +--benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime, + only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker. + +""" + def calculate_subblock_stats( calc_subblock_stats_config: DictConfig, @@ -69,7 +81,7 @@ def calculate_subblock_stats( n_embd: int, n_head: int, vocab_size: int, - benchmark_iterations: int | None, + benchmark_iterations: Optional[int], use_cuda_graph: bool, weights_dtype: torch.dtype, activations_dtype: torch.dtype, @@ -181,6 +193,7 @@ def calculate_subblock_stats( ) if is_calc_runtime: + pass # TODO: fix # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ @@ -206,17 +219,21 @@ def calculate_subblock_stats( def launch_calc_subblock_stats(cfg: DictConfig) -> None: - """Launch the calc subblock stats function with Hydra configuration.""" + """ + Launch the calc subblock stats function with Hydra configuration. + """ mprint(f"Calculating subblock stats for puzzle directory: {cfg.puzzle_dir}") mprint(f"Teacher directory: {cfg.teacher_dir}") mprint( f"Calc subblock stats config: {format_global_config(cfg.calc_subblock_stats, title='Calc subblock stats')}" ) + descriptor = ModelDescriptorFactory.get(cfg.descriptor) calculate_subblock_stats_for_puzzle_dir( cfg.calc_subblock_stats, master_puzzle_dir=cfg.puzzle_dir, teacher_dir=cfg.teacher_dir, + descriptor=descriptor, model_hidden_sizes=cfg.calc_subblock_stats.get("model_hidden_sizes", OmegaConf.create([])), ffn_hidden_sizes=cfg.calc_subblock_stats.get("ffn_hidden_sizes", OmegaConf.create([])), batch_sizes=cfg.calc_subblock_stats.batch_sizes, @@ -224,7 +241,7 @@ def launch_calc_subblock_stats(cfg: DictConfig) -> None: generation_seq_len=cfg.calc_subblock_stats.generation_seq_len, num_active_tokens_override=cfg.calc_subblock_stats.get("num_active_tokens_override", None), prefill_queue_size=cfg.calc_subblock_stats.prefill_queue_size, - allocate_prefill_query=cfg.calc_subblock_stats.allocate_prefill_query, + allocate_prefill_query=cfg.calc_subblock_stats.get("allocate_prefill_query", False), benchmark_iterations=cfg.calc_subblock_stats.get("benchmark_iterations", None), merge_with_existing_stats=cfg.calc_subblock_stats.merge_with_existing_stats, subblock_stats_filename=cfg.calc_subblock_stats.subblock_stats_filename, @@ -236,6 +253,7 @@ def calculate_subblock_stats_for_puzzle_dir( calc_subblock_stats_config: DictConfig, master_puzzle_dir: Path | str, teacher_dir: Path | str, + descriptor: Type[ModelDescriptor], model_hidden_sizes: ListConfig, ffn_hidden_sizes: ListConfig, batch_sizes: Iterable[int] = (1, 8, 16, 32, 64, 128, 256), @@ -268,6 +286,8 @@ def calculate_subblock_stats_for_puzzle_dir( Path(teacher_dir) if teacher_dir is not None else master_puzzle_dir / "ckpts" / "teacher" ) model_config = load_model_config(teacher_dir) + # Get language model config for LM-specific attributes (VL models have nested config) + lm_config = descriptor.get_language_model_config(model_config) subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes, model_config) subblock_stats_file = master_puzzle_dir / subblock_stats_filename @@ -299,7 +319,7 @@ def calculate_subblock_stats_for_puzzle_dir( ] model_hidden_sizes = model_hidden_sizes + [ - model_config.hidden_size + lm_config.hidden_size ] # add a teacher model hidden size for batch_size, ( weights_dtype, @@ -323,8 +343,8 @@ def calculate_subblock_stats_for_puzzle_dir( generation_seq_len=generation_seq_len, prefill_queue_size=prefill_queue_size, n_embd=model_hidden_size, - n_head=model_config.num_attention_heads, - vocab_size=model_config.vocab_size, + n_head=lm_config.num_attention_heads, + vocab_size=lm_config.vocab_size, benchmark_iterations=curr_benchmark_iterations, use_cuda_graph=True, weights_dtype=weights_dtype, @@ -445,7 +465,7 @@ def _load_subblock_configs_from_replacement_library( return subblock_configs -T_DataClass: TypeVar = type[dataclasses.dataclass] +T_DataClass: TypeVar = Type[dataclasses.dataclass] def _dataclass_from_dict( @@ -483,7 +503,7 @@ def add_int8_runtime_estimates(subblock_stats: list[dict]) -> None: if (subblock_config := curr_subblock.get("subblock_config")) is not None: if hasattr(subblock_config, "__dataclass_fields__"): subblock_config = dataclasses.asdict(subblock_config) - is_attention = subblock_config.get("n_heads_in_group", None) is not None + is_attention = subblock_config.get("num_key_value_heads", None) is not None runtime_factor = attention_factor if is_attention else ffn_factor for stat_name, stat_value in bf16_subblock.items(): if "runtime" in stat_name: @@ -512,7 +532,10 @@ def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> di stats for stats in subblock_stats if all( - [stats["args"][key] == corresponding_bf16_args[key] for key in corresponding_bf16_args] + [ + stats["args"][key] == corresponding_bf16_args[key] + for key in corresponding_bf16_args.keys() + ] ) ] if len(matching_bf16_stats) == 0: diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py index e68bb0f439..4a300fcd0b 100644 --- a/modelopt/torch/puzzletron/tools/validate_model.py +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -126,7 +126,6 @@ def validate_model( Returns: A tuple containing: - - losses: Dictionary mapping loss names to loss statistics (avg, per_sample). - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. diff --git a/modelopt/torch/puzzletron/utils/utils.py b/modelopt/torch/puzzletron/utils/utils.py index d56aab0bdb..77a13609aa 100644 --- a/modelopt/torch/puzzletron/utils/utils.py +++ b/modelopt/torch/puzzletron/utils/utils.py @@ -28,24 +28,21 @@ ) -def calculate_kv_dim(n_heads_in_group: int, n_head: int, n_embd: int) -> int: +def calculate_kv_dim(num_key_value_heads: int, n_head: int, n_embd: int) -> int: """Calculate the key-value dimension for grouped-query attention. - TODO: Consider a better place for this function. - Args: - n_heads_in_group: Number of attention heads per key-value group. + num_key_value_heads: Number of key-value heads. n_head: Total number of attention heads. n_embd: Embedding dimension. Returns: - Combined dimension for key and value tensors (2 * n_kv_heads * head_size). + Combined dimension for key and value tensors (2 * num_key_value_heads * head_size). """ - if n_heads_in_group is None: + if num_key_value_heads is None: return 0 - n_kv_heads = n_head // n_heads_in_group head_size = n_embd // n_head - kv_dim = 2 * n_kv_heads * head_size + kv_dim = 2 * num_key_value_heads * head_size return kv_dim @@ -53,7 +50,6 @@ def raise_unknown_subblock_config_error(subblock_config: Any) -> None: """Raise an error for invalid subblock configuration types. TODO: Consider a better place for this function. - Args: subblock_config: The invalid subblock configuration object. @@ -69,7 +65,6 @@ def sizeof_dtype(dtype: torch.dtype) -> int | float: """Return the size in bytes of the given data type. TODO: Consider a better place for this function. - Args: dtype: PyTorch data type or custom type string (e.g., 'nvfp4'). @@ -125,10 +120,10 @@ def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str: def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: - """Convert a BlockConfig to a human-readable string representation. + """ + Convert a BlockConfig to a human-readable string representation. TODO: Consider a better place for this function. - Args: block_config: BlockConfig dataclass or dict containing attention and ffn configs. @@ -153,7 +148,6 @@ def subblock_config_to_str( """Convert a subblock config (FFN, Attention, Mamba, or MoE) to string. TODO: Consider a better place for this function. - Args: subblock_config: FFNConfig, AttentionConfig dataclass or dict. subblock_name: Name of subblock ('ffn', 'attention', 'mamba', 'moe'). @@ -161,7 +155,7 @@ def subblock_config_to_str( Returns: Formatted string showing subblock type and key parameters (e.g., intermediate_size, - n_heads_in_group), or None if input is None. + num_key_value_heads), or None if input is None. """ if subblock_config is None: return None @@ -194,8 +188,8 @@ def subblock_config_to_str( intermediate_size = subblock_config["intermediate_size"] rep += f" intermediate_{intermediate_size}".ljust(8) elif subblock_name == "attention": - n_heads_in_group = subblock_config["n_heads_in_group"] - rep += f" gqa_{n_heads_in_group}".ljust(8) + num_key_value_heads = subblock_config["num_key_value_heads"] + rep += f" kv_heads_{num_key_value_heads}".ljust(8) elif subblock_name == "mamba": mamba_num_heads = subblock_config["mamba"]["num_heads"] mamba_head_dim = subblock_config["mamba"]["head_dim"] @@ -216,7 +210,8 @@ def subblock_config_to_str( class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): def __init__(self, device=None, dtype=None): - """Create tensors with given device and dtype and don't run initialization + """ + Create tensors with given device and dtype and don't run initialization (but instead use "empty tensors", i.e. uninitialized memory). device: `torch.device` to work with @@ -225,8 +220,8 @@ def __init__(self, device=None, dtype=None): Example:: with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): model = LLaMA(model_config) - model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth")) - """ + model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))""" + self.device = device self.dtype = dtype From eb4b210ee94927ffca43416cc680ae2b6d8b73d4 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 13 Mar 2026 00:37:25 +0100 Subject: [PATCH 38/62] Draft: merge any model calc one block scores (#994) ### What does this PR do? Merging dkorzekwa/any_model_calc_one_block_scores into dkorzekwa/anymodel_build_library_and_stats - this MR is only for reviewing. Ultimately dkorzekwa/any_model_calc_one_block_scores should be merged into feature/puzzletron once dkorzekwa/anymodel_build_library_and_stats is merged there. ## Summary by CodeRabbit ## Release Notes * **New Features** * Enabled distributed scoring step in puzzletron workflow for improved performance. * **Improvements** * Extended model loading support to accommodate additional model formats and configurations. * Optimized GPU memory management during validation with improved cache synchronization. * Streamlined validation function signatures for easier usage. --------- Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/puzzletron.py | 4 +- .../replacement_library.py | 103 ++++++++---- ...validate_puzzle_with_multi_replacements.py | 153 ++++++++++-------- .../puzzletron/tools/validation_utils.py | 10 +- 4 files changed, 162 insertions(+), 108 deletions(-) diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 87d90fdd91..262df76489 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -67,8 +67,8 @@ def puzzletron( build_library_and_stats.launch_build_library_and_stats(hydra_cfg) dist.barrier() - # # Step 5: calc_one_block_scores (distributed processing) - # scoring.launch_scoring(hydra_cfg) + # Step 5: calc_one_block_scores (distributed processing) + scoring.launch_scoring(hydra_cfg) # # Step 6: mip_and_realize_models (distributed processing) # mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index bf6cc66362..7935fea4a0 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -12,23 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Replacement library for efficiently loading and managing layer-replaced DeciLM models. +""" +Replacement library for efficiently loading and managing layer-replaced DeciLM models. - Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations """ # mypy: ignore-errors +import copy import json import re +import tempfile from pathlib import Path +from typing import List, Optional -import numpy as np import torch from immutabledict import immutabledict from lru import LRU +from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from torch import nn +from transformers import PretrainedConfig, PreTrainedModel import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.converter.converter import Converter from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, @@ -51,9 +57,11 @@ init_module_with_state_dict, load_model_config, ) +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( create_dummy_model, is_in_safetensors_format, + load_and_shard_model, load_sharded_state_dict, ) @@ -62,8 +70,10 @@ class ReplacementLibrary: def __init__( self, replacement_library_path: str | Path, - model_config_overrides: dict | None = None, + descriptor, + model_config_overrides: Optional[dict] = None, ): + self.descriptor = descriptor self.replacement_library = self._load_replacement_library(replacement_library_path) self._ensure_all_checkpoints_are_split() self.model_config_overrides = ( @@ -114,42 +124,77 @@ def n_layer(self) -> int: def model_config(self) -> DeciLMConfig: if self._model_config is None: self._model_config = load_model_config( - self.get_arbitrary_checkpoint_dir(), self.model_config_overrides + self.get_arbitrary_checkpoint_dir(), + self.model_config_overrides, + ignore_unexpected_config_keys=True, ) return self._model_config def create_model_config(self, layer_replacements: list[dict]): block_configs, _ = extract_block_configs_and_locations(layer_replacements) - model_config = self.model_config.set_block_configs(block_configs) + model_config = copy.deepcopy(self.model_config) + model_config.block_configs = block_configs + model_config.num_hidden_layers = len(block_configs) return model_config - def load_model(self, layer_replacements: list[dict]) -> DeciLMForCausalLM: - block_configs, block_locations = extract_block_configs_and_locations(layer_replacements) - model_config = self.model_config.set_block_configs(block_configs) + def _get_arbitrary_block_checkpoint_paths(self): + checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) + subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name] + return non_block_paths + + def create_index_file_from_weights(self, weight_paths: List[str]): + weight_map = {} + for weight_path in weight_paths: + weight_path = Path(weight_path) + with safe_open(str(weight_path), framework="pt", device="cpu") as f: + for tensor_name in f.keys(): + weight_map[tensor_name] = f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{weight_path.name}" + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + return index + + def prepare_tmp_checkpoint_dir( + self, + tmpdir: Path, + model_config: PretrainedConfig, + layer_replacements: List[dict], + ): + arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) - owned_block_indexes = _get_owned_block_indexes(model_config.get_num_hidden_layers()) - model = create_dummy_model(model_config, self.dtype) + weight_paths = self._get_arbitrary_block_checkpoint_paths() + for layer_replacement in layer_replacements: + weight_paths += layer_replacement["weight_paths"] - is_first_shard = 0 in owned_block_indexes - if is_first_shard and not isinstance(model.model.get_input_embeddings(), nn.Embedding): - model.set_input_embeddings(self.get_embedding()) + weights_index = self.create_index_file_from_weights(weight_paths) + index_path = tmpdir / "model.safetensors.index.json" + with index_path.open("w", encoding="utf-8") as out: + json.dump(weights_index, out, indent=2, sort_keys=True) - is_last_shard = model_config.get_num_hidden_layers() - 1 in owned_block_indexes - if is_last_shard and not isinstance(model.model.get_output_embeddings(), nn.Linear): - model.model.set_final_layer_norm(self.get_ln_f()) - model.set_output_embeddings(self.get_lm_head()) + Converter.copy_checkpoint_files(arbitrary_checkpoint_dir, tmpdir) + save_model_config(model_config, tmpdir) - active_blocks = [] - for block_idx in owned_block_indexes: - layer_replacement, block_idx_in_replacement = block_locations[block_idx] - block = self.get_block(layer_replacement, block_idx_in_replacement) - model.model.layers[block_idx] = block - active_blocks.append(block) + # create symlinks inside tmpdir + subblocks_dir = tmpdir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_dir.mkdir(exist_ok=True) + for weight_path in weight_paths: + link_path = subblocks_dir / weight_path.name + link_path.symlink_to(weight_path) - self._move_inactive_blocks_to_cpu(active_blocks) + def load_model( + self, + layer_replacements: list[dict], + ) -> PreTrainedModel: + """Load model using AnyModel approach with temporary checkpoint directory.""" + model_config = self.create_model_config(layer_replacements) + with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir: + tmpdir = Path(tmpdir) + self.prepare_tmp_checkpoint_dir( + tmpdir, model_config=model_config, layer_replacements=layer_replacements + ) + model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir) return model - def load_checkpoint(self, checkpoint_dir: str | Path) -> DeciLMForCausalLM: + def load_checkpoint(self, checkpoint_dir: str | Path) -> PreTrainedModel: checkpoint_dir = Path(checkpoint_dir).resolve() layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) model = self.load_model(layer_replacements) @@ -221,7 +266,7 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: if len(state_dict) > 0: block_indices = [ int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) - for param_name in state_dict + for param_name in state_dict.keys() ] assert sorted(set(block_indices)) == list( range(min(block_indices), max(block_indices) + 1) @@ -239,7 +284,9 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: } dtype = infer_weights_dtype(state_dict) - model_config = self.model_config.set_block_configs(layer_replacement["child_block_configs"]) + model_config = copy.deepcopy(self.model_config) + model_config.block_configs = layer_replacement["child_block_configs"] + model_config.num_hidden_layers = len(layer_replacement["child_block_configs"]) module_list = nn.ModuleList( [ @@ -316,7 +363,7 @@ def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) return partial_state_dict[param_name] - non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / "non_block.pth" + non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) non_block_state_dict = torch.load(non_block_pth_path) return non_block_state_dict[param_name] diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index 4e3266df4f..d253c94457 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -21,9 +21,11 @@ # mypy: ignore-errors import json +import shutil import warnings from functools import partial from pathlib import Path +from typing import Optional import torch from omegaconf import DictConfig @@ -31,6 +33,8 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.converter import Converter +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.replacement_library.replacement_library import ReplacementLibrary from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement @@ -40,15 +44,15 @@ copy_tokenizer, ) from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( - copy_deci_lm_hf_code, save_checkpoint, save_safetensors_index, ) +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch.puzzletron.tools.validation_utils import ( validate_model_and_extract_hidden_states, validate_model_with_teacher_similarity_metrics, ) -from modelopt.torch.puzzletron.utils.parsing import get_nested_key +from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_path from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import perform_pipeline_stitches """ @@ -68,62 +72,55 @@ def validate_puzzle_solutions(args: DictConfig) -> None: Args: args: Configuration object containing the following attributes: - Puzzle Configuration (Required) attributes: - - - ``replacement_library_path`` (Path): Path to the replacement library JSON file. - - ``solutions_path`` (Path): Path to puzzle solutions JSON file or directory containing solution files. - - ``solutions_to_validate`` (list[int], optional): Indices of specific solutions to validate. - Validates all solutions if None. - - ``sort_solutions_by`` (str, optional): JSON field path to sort solutions by before validation. - - ``bigger_is_better`` (bool): If True, sort solutions in descending order. Used with sort_solutions_by. - - ``skip_validation`` (bool): If True, skip model validation and only save models if requested. - - ``save_models`` (bool): If True, save realized model checkpoints for each solution. - - Teacher/Tokenizer Configuration attributes: - - - ``teacher_dir`` (Path, optional): Path to teacher model directory. Auto-inferred if not provided. - - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. - - Model Configuration (Required if skip_validation=False) attributes: - - - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration (Required if skip_validation=False) attributes: - - - ``dataset_path`` (str): Path to the validation dataset. - - ``data_column`` (str): Column name in dataset containing text data. - - ``block_size`` (int): Maximum sequence length for tokenization. - - ``eval_samples`` (int, optional): Number of samples to evaluate. - - ``val_dataset_name`` (str): Name of validation dataset split. - - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. - - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. - - Data Processing (Required if skip_validation=False) attributes: - - - ``micro_batch_size`` (int): Batch size for evaluation. - - ``seed`` (int): Random seed for reproducibility. - - ``shuffle_seed`` (int, optional): Seed for shuffling data. - - ``varlen`` (bool): Enable variable-length sequences. - - ``bos_rate`` (float): Rate of adding BOS token. - - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. - - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. - - Output Configuration attributes: - - - ``output_dir`` (Path, optional): Directory to save validation results. - Auto-generated from solutions_path if not provided. - - Execution Options (Optional if skip_validation=False) attributes: - - - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. - - ``write_results`` (bool): Write validation results to file. - - ``activations_log_dir`` (str, optional): Directory to log activation scores. - - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. + Puzzle Configuration (Required): + - replacement_library_path (Path): Path to the replacement library JSON file. + - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. Validates all solutions if None. + - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. + - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - skip_validation (bool): If True, skip model validation and only save models if requested. + - save_models (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration: + - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False): + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False): + - dataset_path (str): Path to the validation dataset. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False): + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Output Configuration: + - output_dir (Path, optional): Directory to save validation results. Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False): + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. + - write_results (bool): Write validation results to file. + - activations_log_dir (str, optional): Directory to log activation scores. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. Returns: None. Saves validation results and optionally model checkpoints to disk. """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + puzzle_solutions = load_puzzle_solutions( args.solutions_path, args.sort_solutions_by, args.bigger_is_better ) @@ -143,29 +140,41 @@ def validate_puzzle_solutions(args: DictConfig) -> None: else args.solutions_path.with_name(f"{args.solutions_path.stem}--validation") ) - replacement_library = ReplacementLibrary(args.replacement_library_path) + replacement_library = ReplacementLibrary( + args.replacement_library_path, + descriptor=descriptor, + model_config_overrides={"use_cache": False}, + ) teacher_hidden_states = None if (args.teacher_dir is not None) and (not args.skip_validation): - teacher_model = replacement_library.load_checkpoint(args.teacher_dir) + teacher_model = load_and_shard_model( + checkpoint_path=args.teacher_dir, descriptor=descriptor + ) teacher_model.cuda(dist.local_rank()) - stitched_model = perform_pipeline_stitches(teacher_model) + stitched_model = perform_pipeline_stitches(teacher_model, descriptor=descriptor) teacher_hidden_states = validate_model_and_extract_hidden_states( args, stitched_model, tokenizer, output_dir, model_name="teacher", - pipeline_parallel=True, val_dataloader=val_dataloader, ) + # Properly release CUDA memory after teacher validation + teacher_model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() + for i_solution, puzzle_solution in tqdm( list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions" ): layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution) - # realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) - realizable_as_symlinks = False + realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) + # realizable_as_symlinks = False model_config = replacement_library.create_model_config(layer_replacements) if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): model = replacement_library.load_model(layer_replacements) @@ -177,24 +186,21 @@ def validate_puzzle_solutions(args: DictConfig) -> None: / f"solution_{i_solution}" ) - model_config.dtype = args.model_dtype - model_config.architectures = ["DeciLMForCausalLM"] + model_config.dtype = getattr(args, "model_dtype", "torch.bfloat16") + Converter.copy_checkpoint_files(args.teacher_dir, checkpoint_dir) if realizable_as_symlinks: if dist.is_master(): - save_checkpoint_as_symlinks( - layer_replacements, model_config, checkpoint_dir, replacement_library - ) - else: - save_checkpoint(model, checkpoint_dir) + # save_checkpoint_as_symlinks is currently not supported + pass + save_checkpoint(model, checkpoint_dir, descriptor) copy_tokenizer(args.tokenizer_name, checkpoint_dir) - copy_deci_lm_hf_code(checkpoint_dir) dist.barrier() if not args.skip_validation: model.cuda(dist.local_rank()) - stitched_model = perform_pipeline_stitches(model) + stitched_model = perform_pipeline_stitches(model, descriptor=descriptor) validate_model_with_teacher_similarity_metrics( args, stitched_model, @@ -203,10 +209,15 @@ def validate_puzzle_solutions(args: DictConfig) -> None: output_dir, model_name=f"solution_{i_solution}", extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, - pipeline_parallel=True, val_dataloader=val_dataloader, ) + # Properly release CUDA memory after solution validation + model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() @@ -278,7 +289,7 @@ def _extract_layer_replacements_from_puzzle_solution( def load_puzzle_solutions( solutions_path: Path, - sort_solutions_by: str | None, + sort_solutions_by: Optional[str], bigger_is_better: bool, ) -> list[dict]: assert solutions_path.exists(), f"{solutions_path=} does not exist" diff --git a/modelopt/torch/puzzletron/tools/validation_utils.py b/modelopt/torch/puzzletron/tools/validation_utils.py index 697977cdaf..d7197e8abf 100644 --- a/modelopt/torch/puzzletron/tools/validation_utils.py +++ b/modelopt/torch/puzzletron/tools/validation_utils.py @@ -21,7 +21,7 @@ # mypy: ignore-errors from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional, Union import torch from omegaconf import DictConfig, OmegaConf @@ -44,8 +44,7 @@ def validate_model_and_extract_hidden_states( tokenizer: PreTrainedTokenizerBase, output_dir: str | Path, model_name: str, - extra_payload: dict[str, Any] | None = None, - pipeline_parallel: bool = False, + extra_payload: Optional[dict[str, Any]] = None, val_dataloader=None, ) -> list[torch.Tensor | LowMemorySparseTensor]: mprint(f""" @@ -60,7 +59,6 @@ def validate_model_and_extract_hidden_states( model, tokenizer, return_hidden_states=True, - pipeline_parallel=pipeline_parallel, val_dataloader=val_dataloader, ) if dist.is_last_process(): @@ -77,8 +75,7 @@ def validate_model_with_teacher_similarity_metrics( target_hidden_states_per_batch: list[torch.Tensor], output_dir: str | Path, model_name: str, - extra_payload: dict[str, Any] | None = None, - pipeline_parallel: bool = False, + extra_payload: Optional[dict[str, Any]] = None, calculate_full_score_ablations: bool = False, val_dataloader=None, ) -> None: @@ -95,7 +92,6 @@ def validate_model_with_teacher_similarity_metrics( model, tokenizer, target_hidden_states_per_batch=target_hidden_states_per_batch, - pipeline_parallel=pipeline_parallel, calculate_full_score_ablations=calculate_full_score_ablations, val_dataloader=val_dataloader, ) From 8fe318d46436927cbae53cdffb9319b9b9a88f0d Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 13 Mar 2026 10:22:55 +0100 Subject: [PATCH 39/62] Draft: merge any_model: mip_and_realize_models (#995) ### What does this PR do? Merging dkorzekwa/mip_and_realize_models into dkorzekwa/any_model_calc_one_block_scores - this MR is only for reviewing. Ultimately dkorzekwa/mip_and_realize_models should be merged into feature/puzzletron once dkorzekwa/any_model_calc_one_block_scores is merged there. ## Summary by CodeRabbit ## Release Notes * **New Features** * Enabled model realization step during compression workflow after scoring phase completes. * **Bug Fixes** * Fixed key-value head calculation in attention configuration sourcing. * **Tests** * Strengthened validation checks for compression artifacts and output directories; added rank-aware assertions for model compression expectations. * **Chores** * Minor documentation formatting updates. --------- Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/mip/run_puzzle.py | 4 +- modelopt/torch/puzzletron/puzzletron.py | 4 +- tests/gpu/torch/puzzletron/test_puzzletron.py | 109 +++++++++--------- 3 files changed, 57 insertions(+), 60 deletions(-) diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index 72919d27cd..da0f90452d 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -688,9 +688,7 @@ def _get_block_stats( not (block_config.attention.no_op and block_config.ffn.no_op) ) block_stats["num_kv_heads"] = ( - subblock_stats["args"]["n_head"] // block_config.attention.n_heads_in_group - if block_stats["has_attention"] - else 0 + block_config.attention.num_key_value_heads if block_stats["has_attention"] else 0 ) block_stats["num_local_experts"] = ( block_config.ffn.moe.num_local_experts if block_stats["has_moe"] else 0 diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 262df76489..5a1484e07a 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -70,7 +70,7 @@ def puzzletron( # Step 5: calc_one_block_scores (distributed processing) scoring.launch_scoring(hydra_cfg) - # # Step 6: mip_and_realize_models (distributed processing) - # mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + # Step 6: mip_and_realize_models (distributed processing) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) return hydra_cfg diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index 585567715b..a42a716547 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -120,66 +120,65 @@ def _test_puzzletron_multiprocess_job( ) dist.barrier() - # TODO commented for the duration of merging process from dkorzekwa/any_model to feature/puzzletron # Compress the model using a one-click approach puzzletron.puzzletron( str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) ) - # # - # # Check assertions - # # - # if rank == 0: - # if has_moe_layers: - # # assertions for the score_pruning_activations step 1 (MoE models only) - # rank_filepath = ( - # f"pruning/pruning_scores/expert_removal/10samples_diverse_mini/rank_{rank}.pth" - # ) - # assert (puzzle_dir / rank_filepath).is_file(), f"Expected {rank_filepath} to exist" - - # # assertions for the pruning_ckpts step 2 - # assert (puzzle_dir / "ckpts/num_experts_8").exists() - - # # assertions for the mip_and_realize_models step 6 - # # Find the MIP solution directory dynamically (e.g., stats_num_local_experts_*) - # mip_solutions_dir = puzzle_dir / "mip/puzzle_solutions" - # solution_dirs = [ - # d - # for d in mip_solutions_dir.iterdir() - # if d.is_dir() and d.name.startswith("stats_num_local_experts_") - # ] - # assert len(solution_dirs) == 1, ( - # f"Expected exactly one stats_num_local_experts_* directory, found: {[d.name for d in solution_dirs]}" - # ) - # solution_dir = solution_dirs[0] - - # solution_0_ckpt_config_path = ( - # solution_dir / "solutions--checkpoints/solution_0/config.json" - # ) - # assert solution_0_ckpt_config_path.exists() - # assert (solution_dir / "solutions.json").exists() - - # # Validate lm_loss - # _assert_lm_loss(puzzle_dir, hf_config_name) - # else: - # # assertions for the score_pruning_activations step 1 (FFN pruning) - # _assert_score_pruning_activations(puzzle_dir, hf_config_name) - - # # assertions for the pruning_ckpts step 2 - # assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() - - # # assertions for the mip_and_realize_models step 6 - # _assert_mip_solutions(puzzle_dir, hf_config_name) - - # # assertions for the build_library_and_stats step 4 - # assert (puzzle_dir / "replacement_library.json").is_file() - # assert (puzzle_dir / "subblock_stats.json").is_file() - - # # assertions for the scoring step 5 - # solution_0_filepath = ( - # puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" - # ) - # assert solution_0_filepath.exists() + # + # Check assertions + # + if rank == 0: + if has_moe_layers: + # assertions for the score_pruning_activations step 1 (MoE models only) + rank_filepath = ( + f"pruning/pruning_scores/expert_removal/10samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file(), f"Expected {rank_filepath} to exist" + + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/num_experts_8").exists() + + # assertions for the mip_and_realize_models step 6 + # Find the MIP solution directory dynamically (e.g., stats_num_local_experts_*) + mip_solutions_dir = puzzle_dir / "mip/puzzle_solutions" + solution_dirs = [ + d + for d in mip_solutions_dir.iterdir() + if d.is_dir() and d.name.startswith("stats_num_local_experts_") + ] + assert len(solution_dirs) == 1, ( + f"Expected exactly one stats_num_local_experts_* directory, found: {[d.name for d in solution_dirs]}" + ) + solution_dir = solution_dirs[0] + + solution_0_ckpt_config_path = ( + solution_dir / "solutions--checkpoints/solution_0/config.json" + ) + assert solution_0_ckpt_config_path.exists() + assert (solution_dir / "solutions.json").exists() + + # Validate lm_loss + _assert_lm_loss(puzzle_dir, hf_config_name) + else: + # assertions for the score_pruning_activations step 1 (FFN pruning) + _assert_score_pruning_activations(puzzle_dir, hf_config_name) + + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + # assertions for the mip_and_realize_models step 6 + _assert_mip_solutions(puzzle_dir, hf_config_name) + + # assertions for the build_library_and_stats step 4 + assert (puzzle_dir / "replacement_library.json").is_file() + assert (puzzle_dir / "subblock_stats.json").is_file() + + # assertions for the scoring step 5 + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + assert solution_0_filepath.exists() dist.cleanup() From 2fbdf0e3ef3a4a0b2b10187ba85b5e12813165a1 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 13 Mar 2026 12:16:04 -0700 Subject: [PATCH 40/62] Update uv.lock for nspect puzzletron scanning Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- uv.lock | 256 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 251 insertions(+), 5 deletions(-) diff --git a/uv.lock b/uv.lock index 5931592e7d..c974d6af83 100644 --- a/uv.lock +++ b/uv.lock @@ -147,6 +147,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.9.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" } + [[package]] name = "anyio" version = "4.12.1" @@ -230,6 +236,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/9e/5faefbf9db1db466d633735faceda1f94aa99ce506ac450d232536266b32/cachetools-7.0.1-py3-none-any.whl", hash = "sha256:8f086515c254d5664ae2146d14fc7f65c9a4bce75152eb247e5a9c5e6d7b2ecf", size = 13484, upload-time = "2026-02-10T22:24:03.741Z" }, ] +[[package]] +name = "cbcbox" +version = "2.924" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/86/acd4af8ab4b00dbffac101bb2c87f2d84bdf3e2f3fa79171582c91770178/cbcbox-2.924-py3-none-macosx_15_0_arm64.whl", hash = "sha256:5ba1be40e761c47cbf6e94783f89be60e280d62d049f8f11c69c50fe5719a0ad", size = 30115522, upload-time = "2026-03-12T02:44:27.26Z" }, + { url = "https://files.pythonhosted.org/packages/2f/1c/3d528eb20a94db16c01e14d1f3e307fad73210c35fbd1dfd41b4214b4d64/cbcbox-2.924-py3-none-macosx_15_0_x86_64.whl", hash = "sha256:7f05a5c81c39e94ba32f5230ce9a4c93a899329941fa06f290a538af6121c4cc", size = 59928608, upload-time = "2026-03-12T02:44:30.814Z" }, + { url = "https://files.pythonhosted.org/packages/18/4b/a6ea7c4f600c071a4f6e653054a2172d315269229610d0efb9afcf67af77/cbcbox-2.924-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6842d5a646d650bad77ceddffd89f09a64d111f047b4744ce2e01361f352efed", size = 35904686, upload-time = "2026-03-12T02:44:34.347Z" }, + { url = "https://files.pythonhosted.org/packages/eb/96/9c3d681116a9df29273b48454aec4961b715e6ea01cbeea6ac3636c106d1/cbcbox-2.924-py3-none-manylinux2014_x86_64.whl", hash = "sha256:283c37212a63d2af55ed618653b45d3e66561c8f58e2c5542e3a323e400c6dc3", size = 72713045, upload-time = "2026-03-12T02:44:38.623Z" }, + { url = "https://files.pythonhosted.org/packages/de/29/59500c00eed48de52242889ac94fc6e9ad5fd780646d78847913274fe9e2/cbcbox-2.924-py3-none-win_amd64.whl", hash = "sha256:e5d0b308f89e56bba50286506417fceee0ef085ecf0eda30ed7beb5fea3abb4f", size = 57577749, upload-time = "2026-03-12T02:44:43.133Z" }, +] + [[package]] name = "certifi" version = "2026.2.25" @@ -239,6 +257,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/3c/c17fb3ca2d9c3acff52e30b309f538586f9f5b9c9cf454f3845fc9af4881/certifi-2026.2.25-py3-none-any.whl", hash = "sha256:027692e4402ad994f1c42e52a4997a9763c646b73e4096e4d5d6db8af1d6f0fa", size = 153684, upload-time = "2026-02-25T02:54:15.766Z" }, ] +[[package]] +name = "cffi" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser", marker = "implementation_name != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/d7/516d984057745a6cd96575eea814fe1edd6646ee6efd552fb7b0921dec83/cffi-2.0.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44", size = 184283, upload-time = "2025-09-08T23:22:08.01Z" }, + { url = "https://files.pythonhosted.org/packages/9e/84/ad6a0b408daa859246f57c03efd28e5dd1b33c21737c2db84cae8c237aa5/cffi-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49", size = 180504, upload-time = "2025-09-08T23:22:10.637Z" }, + { url = "https://files.pythonhosted.org/packages/50/bd/b1a6362b80628111e6653c961f987faa55262b4002fcec42308cad1db680/cffi-2.0.0-cp310-cp310-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c", size = 208811, upload-time = "2025-09-08T23:22:12.267Z" }, + { url = "https://files.pythonhosted.org/packages/4f/27/6933a8b2562d7bd1fb595074cf99cc81fc3789f6a6c05cdabb46284a3188/cffi-2.0.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb", size = 216402, upload-time = "2025-09-08T23:22:13.455Z" }, + { url = "https://files.pythonhosted.org/packages/05/eb/b86f2a2645b62adcfff53b0dd97e8dfafb5c8aa864bd0d9a2c2049a0d551/cffi-2.0.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0", size = 203217, upload-time = "2025-09-08T23:22:14.596Z" }, + { url = "https://files.pythonhosted.org/packages/9f/e0/6cbe77a53acf5acc7c08cc186c9928864bd7c005f9efd0d126884858a5fe/cffi-2.0.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4", size = 203079, upload-time = "2025-09-08T23:22:15.769Z" }, + { url = "https://files.pythonhosted.org/packages/98/29/9b366e70e243eb3d14a5cb488dfd3a0b6b2f1fb001a203f653b93ccfac88/cffi-2.0.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453", size = 216475, upload-time = "2025-09-08T23:22:17.427Z" }, + { url = "https://files.pythonhosted.org/packages/21/7a/13b24e70d2f90a322f2900c5d8e1f14fa7e2a6b3332b7309ba7b2ba51a5a/cffi-2.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495", size = 218829, upload-time = "2025-09-08T23:22:19.069Z" }, + { url = "https://files.pythonhosted.org/packages/60/99/c9dc110974c59cc981b1f5b66e1d8af8af764e00f0293266824d9c4254bc/cffi-2.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5", size = 211211, upload-time = "2025-09-08T23:22:20.588Z" }, + { url = "https://files.pythonhosted.org/packages/49/72/ff2d12dbf21aca1b32a40ed792ee6b40f6dc3a9cf1644bd7ef6e95e0ac5e/cffi-2.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb", size = 218036, upload-time = "2025-09-08T23:22:22.143Z" }, + { url = "https://files.pythonhosted.org/packages/e2/cc/027d7fb82e58c48ea717149b03bcadcbdc293553edb283af792bd4bcbb3f/cffi-2.0.0-cp310-cp310-win32.whl", hash = "sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a", size = 172184, upload-time = "2025-09-08T23:22:23.328Z" }, + { url = "https://files.pythonhosted.org/packages/33/fa/072dd15ae27fbb4e06b437eb6e944e75b068deb09e2a2826039e49ee2045/cffi-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739", size = 182790, upload-time = "2025-09-08T23:22:24.752Z" }, + { url = "https://files.pythonhosted.org/packages/12/4a/3dfd5f7850cbf0d06dc84ba9aa00db766b52ca38d8b86e3a38314d52498c/cffi-2.0.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe", size = 184344, upload-time = "2025-09-08T23:22:26.456Z" }, + { url = "https://files.pythonhosted.org/packages/4f/8b/f0e4c441227ba756aafbe78f117485b25bb26b1c059d01f137fa6d14896b/cffi-2.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c", size = 180560, upload-time = "2025-09-08T23:22:28.197Z" }, + { url = "https://files.pythonhosted.org/packages/b1/b7/1200d354378ef52ec227395d95c2576330fd22a869f7a70e88e1447eb234/cffi-2.0.0-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92", size = 209613, upload-time = "2025-09-08T23:22:29.475Z" }, + { url = "https://files.pythonhosted.org/packages/b8/56/6033f5e86e8cc9bb629f0077ba71679508bdf54a9a5e112a3c0b91870332/cffi-2.0.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93", size = 216476, upload-time = "2025-09-08T23:22:31.063Z" }, + { url = "https://files.pythonhosted.org/packages/dc/7f/55fecd70f7ece178db2f26128ec41430d8720f2d12ca97bf8f0a628207d5/cffi-2.0.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5", size = 203374, upload-time = "2025-09-08T23:22:32.507Z" }, + { url = "https://files.pythonhosted.org/packages/84/ef/a7b77c8bdc0f77adc3b46888f1ad54be8f3b7821697a7b89126e829e676a/cffi-2.0.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664", size = 202597, upload-time = "2025-09-08T23:22:34.132Z" }, + { url = "https://files.pythonhosted.org/packages/d7/91/500d892b2bf36529a75b77958edfcd5ad8e2ce4064ce2ecfeab2125d72d1/cffi-2.0.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26", size = 215574, upload-time = "2025-09-08T23:22:35.443Z" }, + { url = "https://files.pythonhosted.org/packages/44/64/58f6255b62b101093d5df22dcb752596066c7e89dd725e0afaed242a61be/cffi-2.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9", size = 218971, upload-time = "2025-09-08T23:22:36.805Z" }, + { url = "https://files.pythonhosted.org/packages/ab/49/fa72cebe2fd8a55fbe14956f9970fe8eb1ac59e5df042f603ef7c8ba0adc/cffi-2.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414", size = 211972, upload-time = "2025-09-08T23:22:38.436Z" }, + { url = "https://files.pythonhosted.org/packages/0b/28/dd0967a76aab36731b6ebfe64dec4e981aff7e0608f60c2d46b46982607d/cffi-2.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743", size = 217078, upload-time = "2025-09-08T23:22:39.776Z" }, + { url = "https://files.pythonhosted.org/packages/2b/c0/015b25184413d7ab0a410775fdb4a50fca20f5589b5dab1dbbfa3baad8ce/cffi-2.0.0-cp311-cp311-win32.whl", hash = "sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5", size = 172076, upload-time = "2025-09-08T23:22:40.95Z" }, + { url = "https://files.pythonhosted.org/packages/ae/8f/dc5531155e7070361eb1b7e4c1a9d896d0cb21c49f807a6c03fd63fc877e/cffi-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5", size = 182820, upload-time = "2025-09-08T23:22:42.463Z" }, + { url = "https://files.pythonhosted.org/packages/95/5c/1b493356429f9aecfd56bc171285a4c4ac8697f76e9bbbbb105e537853a1/cffi-2.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d", size = 177635, upload-time = "2025-09-08T23:22:43.623Z" }, + { url = "https://files.pythonhosted.org/packages/ea/47/4f61023ea636104d4f16ab488e268b93008c3d0bb76893b1b31db1f96802/cffi-2.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d", size = 185271, upload-time = "2025-09-08T23:22:44.795Z" }, + { url = "https://files.pythonhosted.org/packages/df/a2/781b623f57358e360d62cdd7a8c681f074a71d445418a776eef0aadb4ab4/cffi-2.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c", size = 181048, upload-time = "2025-09-08T23:22:45.938Z" }, + { url = "https://files.pythonhosted.org/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe", size = 212529, upload-time = "2025-09-08T23:22:47.349Z" }, + { url = "https://files.pythonhosted.org/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062", size = 220097, upload-time = "2025-09-08T23:22:48.677Z" }, + { url = "https://files.pythonhosted.org/packages/c2/95/7a135d52a50dfa7c882ab0ac17e8dc11cec9d55d2c18dda414c051c5e69e/cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e", size = 207983, upload-time = "2025-09-08T23:22:50.06Z" }, + { url = "https://files.pythonhosted.org/packages/3a/c8/15cb9ada8895957ea171c62dc78ff3e99159ee7adb13c0123c001a2546c1/cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037", size = 206519, upload-time = "2025-09-08T23:22:51.364Z" }, + { url = "https://files.pythonhosted.org/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba", size = 219572, upload-time = "2025-09-08T23:22:52.902Z" }, + { url = "https://files.pythonhosted.org/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94", size = 222963, upload-time = "2025-09-08T23:22:54.518Z" }, + { url = "https://files.pythonhosted.org/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187", size = 221361, upload-time = "2025-09-08T23:22:55.867Z" }, + { url = "https://files.pythonhosted.org/packages/7b/2b/2b6435f76bfeb6bbf055596976da087377ede68df465419d192acf00c437/cffi-2.0.0-cp312-cp312-win32.whl", hash = "sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18", size = 172932, upload-time = "2025-09-08T23:22:57.188Z" }, + { url = "https://files.pythonhosted.org/packages/f8/ed/13bd4418627013bec4ed6e54283b1959cf6db888048c7cf4b4c3b5b36002/cffi-2.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5", size = 183557, upload-time = "2025-09-08T23:22:58.351Z" }, + { url = "https://files.pythonhosted.org/packages/95/31/9f7f93ad2f8eff1dbc1c3656d7ca5bfd8fb52c9d786b4dcf19b2d02217fa/cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6", size = 177762, upload-time = "2025-09-08T23:22:59.668Z" }, +] + [[package]] name = "cfgv" version = "3.5.0" @@ -561,6 +627,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/0f/5d0c71a1aefeb08efff26272149e07ab922b64f46c63363756224bd6872e/filelock-3.24.3-py3-none-any.whl", hash = "sha256:426e9a4660391f7f8a810d71b0555bce9008b0a1cc342ab1f6947d37639e002d", size = 24331, upload-time = "2026-02-19T00:48:18.465Z" }, ] +[[package]] +name = "fire" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "termcolor" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/00/f8d10588d2019d6d6452653def1ee807353b21983db48550318424b5ff18/fire-0.7.1.tar.gz", hash = "sha256:3b208f05c736de98fb343310d090dcc4d8c78b2a89ea4f32b837c586270a9cbf", size = 88720, upload-time = "2025-08-16T20:20:24.175Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/4c/93d0f85318da65923e4b91c1c2ff03d8a458cbefebe3bc612a6693c7906d/fire-0.7.1-py3-none-any.whl", hash = "sha256:e43fd8a5033a9001e7e2973bab96070694b9f12f2e0ecf96d4683971b5ab1882", size = 115945, upload-time = "2025-08-16T20:20:22.87Z" }, +] + [[package]] name = "flatbuffers" version = "25.12.19" @@ -733,6 +811,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794, upload-time = "2021-09-17T21:40:39.897Z" }, ] +[[package]] +name = "hydra-core" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "antlr4-python3-runtime" }, + { name = "omegaconf" }, + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/8e/07e42bc434a847154083b315779b0a81d567154504624e181caf2c71cd98/hydra-core-1.3.2.tar.gz", hash = "sha256:8a878ed67216997c3e9d88a8e72e7b4767e81af37afb4ea3334b269a4390a824", size = 3263494, upload-time = "2023-02-23T18:33:43.03Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547, upload-time = "2023-02-23T18:33:40.801Z" }, +] + [[package]] name = "identify" version = "2.6.16" @@ -760,6 +852,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", size = 8769, upload-time = "2022-07-01T12:21:02.467Z" }, ] +[[package]] +name = "immutabledict" +version = "4.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/e6/718471048fea0366c3e3d1df3acfd914ca66d571cdffcf6d37bbcd725708/immutabledict-4.3.1.tar.gz", hash = "sha256:f844a669106cfdc73f47b1a9da003782fb17dc955a54c80972e0d93d1c63c514", size = 7806, upload-time = "2026-02-15T10:32:34.668Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/ce/f9018bf69ae91b273b6391a095e7c93fa5e1617f25b6ba81ad4b20c9df10/immutabledict-4.3.1-py3-none-any.whl", hash = "sha256:c9facdc0ff30fdb8e35bd16532026cac472a549e182c94fa201b51b25e4bf7bf", size = 5000, upload-time = "2026-02-15T10:32:33.672Z" }, +] + [[package]] name = "importlib-metadata" version = "8.7.1" @@ -842,6 +943,57 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/34/ba8980c3ee52db6e93d110d12e0e7908f863143da107cfd6c557b459e29c/lief-0.17.4-cp312-cp312-win_arm64.whl", hash = "sha256:b839bb150080122a7a5bd17a958a2b6f5add96ff535eff0c692570bcd257d0b3", size = 3460387, upload-time = "2026-02-21T09:30:27.889Z" }, ] +[[package]] +name = "lru-dict" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/06/0a/dec86efe38b350314c49a8d39ef01ba7cf8bbbef1d177646320eedea7159/lru_dict-1.4.1.tar.gz", hash = "sha256:cc518ff2d38cc7a8ab56f9a6ae557f91e2e1524b57ed8e598e97f45a2bd708fc", size = 13439, upload-time = "2025-11-02T10:02:13.548Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/6c/396716746ca46fd2ac52a7a6cbd7b4cf848e5d430f431dacd209290dfa71/lru_dict-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3766e397aa6de1ca3442729bc1fa75834ab7b0a6b017e6e197d3a66b61abde59", size = 16757, upload-time = "2025-11-02T10:00:55.767Z" }, + { url = "https://files.pythonhosted.org/packages/2d/93/c163ffb71beb18f18459461658fd16c8b8c86aed858f2dc7c7e636318f61/lru_dict-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658e152d3a4ad0e1d75e6f53b1fa353779539920b38be99f4ea33d3bad41efdb", size = 11243, upload-time = "2025-11-02T10:00:56.715Z" }, + { url = "https://files.pythonhosted.org/packages/44/e3/fa96d54032531c67eeacf0ab6f56e10e05f25d382a29f6a381ac8ecf3814/lru_dict-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:98af7044b5c3d85a649e1afb8891829ff5210caf9143acc741b3e98ab1b66ff6", size = 11726, upload-time = "2025-11-02T10:00:57.377Z" }, + { url = "https://files.pythonhosted.org/packages/7a/23/bae4f32fb014fd2dc5512e9267a3b1ec34c3b55d16a2202a1193d9ae635d/lru_dict-1.4.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:906d99705b79a00b5668bdb8782ad823ccc8d26e1fc6b56327ae469a8d12e9b4", size = 29823, upload-time = "2025-11-02T10:00:58.34Z" }, + { url = "https://files.pythonhosted.org/packages/9f/3b/8c3d1e6a188ce65e0161b86dbd18f2290950baf1e9e28e4948fc123d9a67/lru_dict-1.4.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:885643fd968336d8652fddb0778184e2eeff7b7aebced6de268af6d6caef42d5", size = 30812, upload-time = "2025-11-02T10:00:59.358Z" }, + { url = "https://files.pythonhosted.org/packages/ed/11/7f061507eda944150ed959e99a3700ce6358c1241c7f697b2f1ade48646b/lru_dict-1.4.1-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:24c779334bed82f1a7eb2d1ebcba2b7aa9a1555d40a3b53e05eb6b9dfcb0609c", size = 32480, upload-time = "2025-11-02T10:01:00.141Z" }, + { url = "https://files.pythonhosted.org/packages/75/e7/94ac30d33c6f8a8eca5d7e81c0ce26fb7b79b18ea65accdcb2a652b19abc/lru_dict-1.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c6099e2ecb118dfeae4a197bfcc702ea5841bfd86f19d1b340e932d0f5c47c10", size = 30199, upload-time = "2025-11-02T10:01:01.31Z" }, + { url = "https://files.pythonhosted.org/packages/4a/81/c93ee7365db67dfb497e6218aa0395b9ec878c07c732d348bfbd651bcc95/lru_dict-1.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4e0db4f3105108598749550e639b283b07df0bb91cac3b47e86ffebcab721cc7", size = 31489, upload-time = "2025-11-02T10:01:02.363Z" }, + { url = "https://files.pythonhosted.org/packages/9f/0b/634e8b4eca2497647f802bbe1ae3f0e1e9a0de1d555cf77c022527b2682f/lru_dict-1.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e21f67ba374d1945051b547e719d44a8c7880718f67a15a03e7a12e1d12ea96b", size = 29522, upload-time = "2025-11-02T10:01:03.399Z" }, + { url = "https://files.pythonhosted.org/packages/de/cc/591b959d77cc0e0ac016f11baf26d03d566bb88a53fa9b41e157bc68bc4b/lru_dict-1.4.1-cp310-cp310-win32.whl", hash = "sha256:f309b4018dd41f33bf3bd4cc0f62421da8bcca513ea044dbb22f3cd029935012", size = 13066, upload-time = "2025-11-02T10:01:04.457Z" }, + { url = "https://files.pythonhosted.org/packages/d9/bc/c14b67fdbdb5a2a81cfb907ea8a8b0c9da5aed899f34921ebf097e22a966/lru_dict-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:e84cd1065955897de01f1fb4cbd6f87cab7706e920283bb98c27341d76dd9a8d", size = 14008, upload-time = "2025-11-02T10:01:05.421Z" }, + { url = "https://files.pythonhosted.org/packages/4c/ff/1d02bc444174f07d3ce747568989969c97dc77d0513f4c3b8b6224cb976f/lru_dict-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cc74c49cf1c26d6c28d8f6988cf0354696ca38a4f6012fa63055d2800791784b", size = 16760, upload-time = "2025-11-02T10:01:06.492Z" }, + { url = "https://files.pythonhosted.org/packages/0b/d8/e2e970272ea5fe7ba6349a5e7d0bb0fd814f5d1b88a53bc72b8c2a5e034f/lru_dict-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0158db85dfb2cd2fd2ddaa47709bdb073f814e0a8a149051b70b07e59ac83231", size = 11249, upload-time = "2025-11-02T10:01:07.261Z" }, + { url = "https://files.pythonhosted.org/packages/a5/26/860b5e60f339f8038118028388926224c8b70779e8243d68772e0e0d0ab3/lru_dict-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c8ac5cfd56e036bd8d7199626147044485fa64a163a5bde96bfa5a1c7fea2273", size = 11728, upload-time = "2025-11-02T10:01:08.185Z" }, + { url = "https://files.pythonhosted.org/packages/61/55/fc8f71953fd343ede33810b0a000b4130e03635ae09b28569e45735ded2f/lru_dict-1.4.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2eb2058cb7b329b4b72baee4cd1bb322af1feec73de79e68edb35d333c90b698", size = 30795, upload-time = "2025-11-02T10:01:08.862Z" }, + { url = "https://files.pythonhosted.org/packages/4c/26/ad549550e6a236818a91434570d38d7a93824b0410d3db1c845a53238e1f/lru_dict-1.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ffbb6f3c1e906e92d9129c14a88d81358be1e0b60195c1729b215a52e9670de", size = 31807, upload-time = "2025-11-02T10:01:09.581Z" }, + { url = "https://files.pythonhosted.org/packages/7c/39/72dae9ac0e95a8576a45e3bd62a6fc3e7dbb116794efa1337c7b450d4836/lru_dict-1.4.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:11b289d78a48a086846e46d2275707d33523f5d543475336c29c56fd5d0e65dc", size = 33437, upload-time = "2025-11-02T10:01:10.676Z" }, + { url = "https://files.pythonhosted.org/packages/a8/46/221479834703a5397fa32f07212ace38f104a31ad1af8a921cf25e053677/lru_dict-1.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3fe10c1f45712e191eecb2a69604d566c64ddfe01136fd467c890ed558c3ad40", size = 31168, upload-time = "2025-11-02T10:01:11.47Z" }, + { url = "https://files.pythonhosted.org/packages/6e/13/98d36e2522fda7f6625c15332562f81f1465161a5ae021d9b3b408f8c427/lru_dict-1.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e04820e3473bd7f55440f24c946ca4335e392d5e3e0e1e948020e94cd1954372", size = 32454, upload-time = "2025-11-02T10:01:12.522Z" }, + { url = "https://files.pythonhosted.org/packages/49/18/345ff2a98d27cddae40c84cf0466fcc329f3965cd21322bb561a94e4d332/lru_dict-1.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:edc004c88911a8f9715e716116d2520c13db89afd6c37cc0f28042ba10635163", size = 30574, upload-time = "2025-11-02T10:01:13.293Z" }, + { url = "https://files.pythonhosted.org/packages/d7/92/dfea71402a7ca46332bcb854827ee68bbc9be205e2558c3a40293eca9782/lru_dict-1.4.1-cp311-cp311-win32.whl", hash = "sha256:b0b5360264b37676c405ea0a560744d7dcb2d47adff1e7837113c15fabcc7a71", size = 13031, upload-time = "2025-11-02T10:01:13.96Z" }, + { url = "https://files.pythonhosted.org/packages/3a/7b/4c7d566d77ec3ad9128f07407494c2aec57909f8dd59f0c9910bd4c05840/lru_dict-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:bb4b37daad9fe4e796c462f4876cf34e52564630902bdf59a271bc482b48a361", size = 14007, upload-time = "2025-11-02T10:01:14.857Z" }, + { url = "https://files.pythonhosted.org/packages/4f/a8/89e4c26e0e751321b41b0a3007384f97d9eae7a863c49af1c68c43005ca3/lru_dict-1.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:7fa342c6e6bc811ee6a17eb569d37b149340d5aa5a637a53438e316a95783838", size = 16683, upload-time = "2025-11-02T10:01:15.891Z" }, + { url = "https://files.pythonhosted.org/packages/f1/34/b3c6fdd120af68b6eeb524d0de3293ff27918ec57f45eed6bef1789fd085/lru_dict-1.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bd86bd202a7c1585d9dc7e5b0c3d52cf76dc56b261b4bbecfeefbbae31a5c97d", size = 11216, upload-time = "2025-11-02T10:01:16.867Z" }, + { url = "https://files.pythonhosted.org/packages/e9/7e/280267ae23f1ec1074ddaab787c5e041e090220e8e37828d51ff4e681dfd/lru_dict-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4617554f3e42a8f520c8494842c23b98f5b7f4d5e0410e91a4c3ad0ea5f7e094", size = 11687, upload-time = "2025-11-02T10:01:17.485Z" }, + { url = "https://files.pythonhosted.org/packages/ca/18/fec42416ceff98ae2760067ec72b0b9fc02840e729bbc18059c6a02cb01f/lru_dict-1.4.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:40927a6a4284d437047f547e652b15f6f0f40210deb6b9e5b77e556ff0faea0f", size = 31960, upload-time = "2025-11-02T10:01:18.158Z" }, + { url = "https://files.pythonhosted.org/packages/c2/ef/38e7ee1a5d32b9b1629d045fa5a495375383aacfb2945f4d9535b9af9630/lru_dict-1.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2c07ecb6d42494e45d00c2541e6b0ae7659fc3cf89681521ba94b15c682d4fe", size = 32882, upload-time = "2025-11-02T10:01:18.841Z" }, + { url = "https://files.pythonhosted.org/packages/72/82/d56653ca144c291ab37bea5f23c5078ffbe64f7f5b466f91d400590b9106/lru_dict-1.4.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:85b28aa2de7c5f1f6c68221857accd084438df98edbd4f57595795734225770c", size = 34268, upload-time = "2025-11-02T10:01:19.564Z" }, + { url = "https://files.pythonhosted.org/packages/94/ae/382651533d60f0b598757efda56dc87cad5ac311fba8e61f86fb916bf236/lru_dict-1.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:cbbbb4b51e2529ccf7ee8a3c3b834052dbd54871a216cfd229dd2b1194ff293a", size = 32156, upload-time = "2025-11-02T10:01:20.22Z" }, + { url = "https://files.pythonhosted.org/packages/aa/d1/d9df7e9272ccbc96f04c477dfb9abb91fa8fabde86b7fa190cb7b3c7a024/lru_dict-1.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:e47040421a13de8bc6404557b3700c33f1f2683cbcce22fe5cacec4c938ce54b", size = 33395, upload-time = "2025-11-02T10:01:20.901Z" }, + { url = "https://files.pythonhosted.org/packages/e9/6e/dafe0f5943a7b3ab24d3429032ff85873acd626087934b8161b55340c13a/lru_dict-1.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:451f7249866cb9564bb40d73bec7ac865574dafd0a4cc91627bbf35be7e99291", size = 31591, upload-time = "2025-11-02T10:01:21.606Z" }, + { url = "https://files.pythonhosted.org/packages/a6/4d/9dd35444592bfb6805548e15971cfce821400966a51130b78dc021ee8f03/lru_dict-1.4.1-cp312-cp312-win32.whl", hash = "sha256:e8996f3f94870ecb236c55d280839390edae7f201858fee770267eac27b8b47d", size = 13119, upload-time = "2025-11-02T10:01:22.61Z" }, + { url = "https://files.pythonhosted.org/packages/8d/82/7e72e30d6c15d65466b3baca87cce15e20848ba6a488868aa54e901141a6/lru_dict-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:d90774db1b60c0d5c829cfa5d7fda6db96ed1519296f626575598f9f170cca37", size = 14109, upload-time = "2025-11-02T10:01:23.322Z" }, + { url = "https://files.pythonhosted.org/packages/ec/de/18ac3957e1aa6674a0a828748c819265f79b524ff30cbb0ac7f08ab786c8/lru_dict-1.4.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:cc9dd191870555624bbf3903c8afa3f01815ca3256ed8b35cb323f0db3ce4f98", size = 10467, upload-time = "2025-11-02T10:02:05.717Z" }, + { url = "https://files.pythonhosted.org/packages/0c/53/2a0bedaa64950cc56ade72e2f5a292318473585d9a3adc797d13b38082e7/lru_dict-1.4.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:afdf92b332632aa6e4b8646e93723f50f41fece2a80a54d2b44e8ac67f913ceb", size = 10871, upload-time = "2025-11-02T10:02:06.353Z" }, + { url = "https://files.pythonhosted.org/packages/4e/e2/d5ea49d62ea142559fd9cafd8505d4a4f87a1d81953a9c99fa61e7ccbd6b/lru_dict-1.4.1-pp310-pypy310_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3d6770adafae25663b682420891a10a5894595f02b1e4d87766f7adc8e56e72a", size = 12969, upload-time = "2025-11-02T10:02:07.196Z" }, + { url = "https://files.pythonhosted.org/packages/a2/67/0672caac9a04dc9011f7a27fc2ec2003f0bfa008070b29940d05b4dae56a/lru_dict-1.4.1-pp310-pypy310_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:018cd3b41224ca81eb83cdf6db024409a920e5c1d3ce4e8b323cb66e24a73132", size = 13959, upload-time = "2025-11-02T10:02:08.267Z" }, + { url = "https://files.pythonhosted.org/packages/e3/7e/313385214a5011cf9fe8376928f66f70bfedc48d8f7ab424292224ed4907/lru_dict-1.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:781dbcf0c83160e525482a4ebcd7c5065851a6c7295f1cda78248a2029f23f39", size = 14084, upload-time = "2025-11-02T10:02:08.993Z" }, + { url = "https://files.pythonhosted.org/packages/8e/47/08c61cad038706b3a89b8c7587ec74ed9731c1e536329745cccb6c840916/lru_dict-1.4.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9219f13e4101c064f70e1815d7c51f9be9e053983e74dfb7bcfdf92f5fcbb0e0", size = 10384, upload-time = "2025-11-02T10:02:09.656Z" }, + { url = "https://files.pythonhosted.org/packages/6b/a1/022c4d7c68c076370231488c97cf7451131fb9ca0d60d1b2785e7baa1f5b/lru_dict-1.4.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7e1ac7fb6e91e4d3212e153f9e2d98d163a4439b9bf9df247c22519262c26fe", size = 10822, upload-time = "2025-11-02T10:02:10.609Z" }, + { url = "https://files.pythonhosted.org/packages/65/b4/4c0a0877a77fececa9f58d804569e2aac1bfbe588e3a70e79647b5d8f7d4/lru_dict-1.4.1-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:23424321b761c43f3021a596565f8205ecec0e175822e7a5d9b2a175578aa7de", size = 12968, upload-time = "2025-11-02T10:02:11.405Z" }, + { url = "https://files.pythonhosted.org/packages/22/06/d7e393d07dc31e656330d5a058f34e972bf590e7dc882922b426f3aec4a0/lru_dict-1.4.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:804ee76f98afc3d50e9a2e9c835a6820877aa6391f2add520a57f86b3f55ec3a", size = 13904, upload-time = "2025-11-02T10:02:12.144Z" }, + { url = "https://files.pythonhosted.org/packages/e8/1e/0eee8bcc16bf01b265ac83e4b870596e2f3bcc40d88aa7ec25407180fe44/lru_dict-1.4.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:3be24e24c8998302ea1c28f997505fa6843f507aad3c7d5c3a82cc01c5c11be4", size = 14062, upload-time = "2025-11-02T10:02:12.878Z" }, +] + [[package]] name = "mako" version = "1.3.10" @@ -916,6 +1068,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "mip" +version = "1.17.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cbcbox" }, + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/69/8b7695d78b96e997a691814e992732d98fa0f92c5c2a2885ec607f759aba/mip-1.17.4.tar.gz", hash = "sha256:0e7ca54424614bb9670795cc22cb7f700baf5a12c59bbc25af10b723bf0b64eb", size = 9443521, upload-time = "2026-03-12T16:44:59.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/1e/b02b0c57c2304944c371be03eb0b4293f66c81048e09fc5ce11042225b2b/mip-1.17.4-py3-none-any.whl", hash = "sha256:3b140b2954a50595ad4f9f263027087add3d22129ea29382d2ea2bacf79b485d", size = 88148, upload-time = "2026-03-12T16:44:57.755Z" }, +] + [[package]] name = "ml-dtypes" version = "0.5.4" @@ -1312,10 +1477,16 @@ all = [ { name = "datasets" }, { name = "deepspeed", marker = "sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "diffusers" }, + { name = "fire" }, { name = "huggingface-hub" }, + { name = "hydra-core" }, + { name = "immutabledict" }, { name = "lief" }, + { name = "lru-dict" }, + { name = "mip" }, { name = "ml-dtypes" }, { name = "nltk" }, + { name = "omegaconf" }, { name = "onnx" }, { name = "onnx-graphsurgeon" }, { name = "onnxconverter-common" }, @@ -1325,9 +1496,13 @@ all = [ { name = "onnxruntime-gpu", version = "1.24.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "onnxscript" }, { name = "onnxslim" }, + { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "peft" }, { name = "polygraphy" }, + { name = "sentencepiece" }, { name = "transformers" }, + { name = "typeguard" }, { name = "wonderwords" }, ] dev = [ @@ -1339,11 +1514,17 @@ dev = [ { name = "datasets" }, { name = "deepspeed", marker = "sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "diffusers" }, + { name = "fire" }, { name = "huggingface-hub" }, + { name = "hydra-core" }, + { name = "immutabledict" }, { name = "lief" }, + { name = "lru-dict" }, + { name = "mip" }, { name = "ml-dtypes" }, { name = "mypy" }, { name = "nltk" }, + { name = "omegaconf" }, { name = "onnx" }, { name = "onnx-graphsurgeon" }, { name = "onnxconverter-common" }, @@ -1353,6 +1534,8 @@ dev = [ { name = "onnxruntime-gpu", version = "1.24.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "onnxscript" }, { name = "onnxslim" }, + { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "peft" }, { name = "polygraphy" }, { name = "pre-commit" }, @@ -1377,6 +1560,7 @@ dev = [ { name = "tox" }, { name = "tox-current-env" }, { name = "transformers" }, + { name = "typeguard" }, { name = "wonderwords" }, ] dev-docs = [ @@ -1401,7 +1585,6 @@ dev-test = [ { name = "pytest-cov" }, { name = "pytest-instafail" }, { name = "pytest-timeout" }, - { name = "sentencepiece" }, { name = "timm" }, { name = "torch-geometric" }, { name = "torchprofile" }, @@ -1417,6 +1600,7 @@ hf = [ { name = "huggingface-hub" }, { name = "nltk" }, { name = "peft" }, + { name = "sentencepiece" }, { name = "transformers" }, { name = "wonderwords" }, ] @@ -1436,6 +1620,17 @@ onnx = [ { name = "onnxslim" }, { name = "polygraphy" }, ] +puzzletron = [ + { name = "fire" }, + { name = "hydra-core" }, + { name = "immutabledict" }, + { name = "lru-dict" }, + { name = "mip" }, + { name = "omegaconf" }, + { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "typeguard" }, +] [package.metadata] requires-dist = [ @@ -1447,8 +1642,13 @@ requires-dist = [ { name = "datasets", marker = "extra == 'hf'", specifier = ">=3.0.0" }, { name = "deepspeed", marker = "sys_platform != 'darwin' and sys_platform != 'win32' and extra == 'hf'", specifier = ">=0.9.6" }, { name = "diffusers", marker = "extra == 'hf'", specifier = ">=0.32.2" }, + { name = "fire", marker = "extra == 'puzzletron'" }, { name = "huggingface-hub", marker = "extra == 'hf'", specifier = ">=0.24.0" }, + { name = "hydra-core", marker = "extra == 'puzzletron'", specifier = "==1.3.2" }, + { name = "immutabledict", marker = "extra == 'puzzletron'" }, { name = "lief", marker = "extra == 'onnx'" }, + { name = "lru-dict", marker = "extra == 'puzzletron'" }, + { name = "mip", marker = "extra == 'puzzletron'" }, { name = "ml-dtypes", marker = "extra == 'onnx'" }, { name = "mypy", marker = "extra == 'dev-lint'", specifier = "==1.17.1" }, { name = "ninja" }, @@ -1456,7 +1656,8 @@ requires-dist = [ { name = "numpy" }, { name = "nvidia-ml-py", specifier = ">=12" }, { name = "nvidia-modelopt", extras = ["all", "dev-docs", "dev-lint", "dev-test"], marker = "extra == 'dev'" }, - { name = "nvidia-modelopt", extras = ["hf", "onnx"], marker = "extra == 'all'" }, + { name = "nvidia-modelopt", extras = ["hf", "onnx", "puzzletron"], marker = "extra == 'all'" }, + { name = "omegaconf", marker = "extra == 'puzzletron'", specifier = "==2.3.0" }, { name = "onnx", marker = "extra == 'onnx'", specifier = "~=1.19.0" }, { name = "onnx-graphsurgeon", marker = "extra == 'onnx'" }, { name = "onnxconverter-common", marker = "extra == 'onnx'", specifier = "~=1.16.0" }, @@ -1468,6 +1669,7 @@ requires-dist = [ { name = "onnxscript", marker = "extra == 'onnx'" }, { name = "onnxslim", marker = "extra == 'onnx'", specifier = ">=0.1.76" }, { name = "packaging" }, + { name = "pandas", marker = "extra == 'puzzletron'" }, { name = "peft", marker = "extra == 'hf'", specifier = ">=0.17.0" }, { name = "polygraphy", marker = "extra == 'onnx'", specifier = ">=0.49.22" }, { name = "pre-commit", marker = "extra == 'dev-lint'", specifier = "==4.3.0" }, @@ -1482,7 +1684,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev-lint'", specifier = "==0.12.11" }, { name = "safetensors" }, { name = "scipy" }, - { name = "sentencepiece", marker = "extra == 'dev-test'" }, + { name = "sentencepiece", marker = "extra == 'hf'", specifier = ">=0.2.1" }, { name = "setuptools", specifier = ">=80" }, { name = "sphinx", marker = "extra == 'dev-docs'", specifier = "~=8.1.0" }, { name = "sphinx-argparse", marker = "extra == 'dev-docs'", specifier = ">=0.5.2" }, @@ -1494,15 +1696,29 @@ requires-dist = [ { name = "timm", marker = "extra == 'dev-test'" }, { name = "torch", specifier = ">=2.6" }, { name = "torch-geometric", marker = "extra == 'dev-test'" }, - { name = "torchprofile", marker = "extra == 'dev-test'", specifier = ">=0.0.4" }, + { name = "torchprofile", marker = "extra == 'dev-test'", specifier = "==0.0.4" }, { name = "torchvision", marker = "extra == 'dev-test'" }, { name = "tox", marker = "extra == 'dev-test'", specifier = ">4.18" }, { name = "tox-current-env", marker = "extra == 'dev-test'", specifier = ">=0.0.12" }, { name = "tqdm" }, { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53,<5.0" }, + { name = "typeguard", marker = "extra == 'puzzletron'" }, { name = "wonderwords", marker = "extra == 'hf'" }, ] -provides-extras = ["onnx", "hf", "dev-lint", "dev-docs", "dev-test", "all", "dev"] +provides-extras = ["onnx", "hf", "puzzletron", "dev-lint", "dev-docs", "dev-test", "all", "dev"] + +[[package]] +name = "omegaconf" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "antlr4-python3-runtime" }, + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, +] [[package]] name = "onnx" @@ -2068,6 +2284,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/c5/e98d9c51f3d5300d5e40ad9037dd6b3b60736fd02ab68dcc98c96be7592d/pybind11-3.0.2-py3-none-any.whl", hash = "sha256:f8a6500548919cc33bcd220d5f984688326f574fa97f1107f2f4fdb4c6fb019f", size = 310158, upload-time = "2026-02-17T04:46:49.91Z" }, ] +[[package]] +name = "pycparser" +version = "3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/7d/92392ff7815c21062bea51aa7b87d45576f649f16458d78b7cf94b9ab2e6/pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29", size = 103492, upload-time = "2026-01-21T14:26:51.89Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, +] + [[package]] name = "pydantic" version = "2.12.5" @@ -2873,6 +3098,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "termcolor" +version = "3.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/46/79/cf31d7a93a8fdc6aa0fbb665be84426a8c5a557d9240b6239e9e11e35fc5/termcolor-3.3.0.tar.gz", hash = "sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5", size = 14434, upload-time = "2025-12-29T12:55:21.882Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/d1/8bb87d21e9aeb323cc03034f5eaf2c8f69841e40e4853c2627edf8111ed3/termcolor-3.3.0-py3-none-any.whl", hash = "sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5", size = 7734, upload-time = "2025-12-29T12:55:20.718Z" }, +] + [[package]] name = "timm" version = "1.0.25" @@ -3089,6 +3323,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" }, ] +[[package]] +name = "typeguard" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2b/e8/66e25efcc18542d58706ce4e50415710593721aae26e794ab1dec34fb66f/typeguard-4.5.1.tar.gz", hash = "sha256:f6f8ecbbc819c9bc749983cc67c02391e16a9b43b8b27f15dc70ed7c4a007274", size = 80121, upload-time = "2026-02-19T16:09:03.392Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/88/b55b3117287a8540b76dbdd87733808d4d01c8067a3b339408c250bb3600/typeguard-4.5.1-py3-none-any.whl", hash = "sha256:44d2bf329d49a244110a090b55f5f91aa82d9a9834ebfd30bcc73651e4a8cc40", size = 36745, upload-time = "2026-02-19T16:09:01.6Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From 1b42f0b16e090655d57f528808e84e3b7d16a998 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 17 Mar 2026 11:56:01 +0100 Subject: [PATCH 41/62] Dkorzekwa/any model other models (#1007) ### What does this PR do? Merging dkorzekwa/any_model_other_models into dkorzekwa/mip_and_realize_models - this MR is only for reviewing. Ultimately dkorzekwa/any_model_other_models should be merged into feature/puzzletron once dkorzekwa/mip_and_realize_models is merged there. ## Summary by CodeRabbit * **New Features** * Added support for multiple model architectures: Mistral Small, Nemotron H, Nemotron H v2, Qwen2, Qwen3 8B, and Qwen3 VL 30B. * Introduced new pruning configurations and optimization pipelines for supported models. * Added comprehensive model descriptor framework enabling automated weight conversion and configuration handling. * Extended support for Mixture of Experts (MoE) models with expert removal pruning capabilities. * **Tests** * Enhanced test coverage with parametrized configurations for multiple model variants. --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../nas/plugins/megatron_hooks/base_hooks.py | 72 ++--- .../activation_hooks/utils.py | 31 ++- .../anymodel/converter/converter.py | 8 +- .../model_descriptor/model_descriptor.py | 12 + .../puzzletron/anymodel/models/__init__.py | 12 +- .../models/llama/llama_model_descriptor.py | 10 + .../anymodel/models/mistral_small/__init__.py | 21 ++ .../mistral_small/mistral_small_converter.py | 41 +++ .../mistral_small_model_descriptor.py | 135 +++++++++ .../anymodel/models/nemotron_h/__init__.py | 21 ++ .../models/nemotron_h/nemotron_h_converter.py | 84 ++++++ .../nemotron_h/nemotron_h_model_descriptor.py | 256 ++++++++++++++++++ .../anymodel/models/nemotron_h_v2/__init__.py | 21 ++ .../nemotron_h_v2/nemotron_h_v2_converter.py | 84 ++++++ .../nemotron_h_v2_model_descriptor.py | 241 +++++++++++++++++ .../anymodel/models/qwen2/__init__.py | 19 ++ .../anymodel/models/qwen2/qwen2_converter.py | 50 ++++ .../models/qwen2/qwen2_model_descriptor.py | 148 ++++++++++ .../anymodel/models/qwen3_8b/__init__.py | 19 ++ .../models/qwen3_8b/qwen3_8b_converter.py | 42 +++ .../qwen3_8b/qwen3_8b_model_descriptor.py | 152 +++++++++++ .../qwen3_vl_30b_a3b_instruct/__init__.py | 21 ++ .../qwen3_vl_30b_a3b_instruct_converter.py | 77 ++++++ ...n3_vl_30b_a3b_instruct_model_descriptor.py | 212 +++++++++++++++ .../puzzletron/anymodel/puzzformer/no_op.py | 2 +- .../decilm/deci_lm_hf_code/modeling_decilm.py | 6 +- modelopt/torch/puzzletron/mip/run_puzzle.py | 11 +- .../pruning/expert_removal_pruning_mixin.py | 2 - .../torch/puzzletron/pruning/pruning_ckpts.py | 15 +- .../build_replacement_library.py | 57 +++- .../replacement_library.py | 2 + .../subblock_stats/calc_subblock_stats.py | 3 +- .../init_child_from_parent.py | 10 +- .../puzzletron/tools/checkpoint_utils.py | 7 +- .../puzzletron/tools/checkpoint_utils_hf.py | 28 +- .../tools/sharded_checkpoint_utils.py | 19 +- .../configs/pruning/ffn_pruning.yaml | 12 - .../configs/pruning/pruning_defaults.yaml | 32 --- .../configs/validate_model_defaults.yaml | 17 -- .../tokenizer/special_tokens_map.json | 16 -- .../resources/tokenizer/tokenizer.json | 212 --------------- .../resources/tokenizer/tokenizer_config.json | 13 - .../resources/tokenizer/truncate_tokenizer.py | 62 ----- tests/_test_utils/torch/puzzletron/utils.py | 50 ++-- .../nas/plugins/test_nas_convert.py | 19 +- .../puzzletron/nas/plugins/test_nas_search.py | 10 +- .../Qwen2.5-7B-Instruct.yaml} | 30 +- .../pruning/ffn_pruning.yaml | 7 + .../configs/Qwen/Qwen3-8B/Qwen3-8B.yaml} | 31 ++- .../Qwen/Qwen3-8B/pruning/ffn_pruning.yaml | 7 + .../Qwen3-VL-30B-A3B-Instruct.yaml | 113 ++++++++ .../pruning/expert_pruning.yaml | 20 ++ .../pruning/attn_pruning.yaml | 16 -- .../pruning/hidden_dim_pruning.yaml | 15 - .../validate_solutions_defaults.yaml | 10 - .../Llama-3.1-8B-Instruct-attn-pruning.yaml | 10 + .../Llama-3.1-8B-Instruct.yaml} | 9 +- .../pruning/attn_pruning.yaml | 7 + .../pruning/ffn_pruning.yaml | 7 + .../Llama-3.2-3B-Instruct.yaml} | 9 +- .../pruning/ffn_pruning.yaml | 7 + .../Mistral-Small-24B-Instruct-2501.yaml | 112 ++++++++ .../pruning/ffn_pruning.yaml | 7 + ...DIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml | 115 ++++++++ .../pruning/expert_pruning.yaml | 18 ++ .../pruning/ffn_pruning.yaml | 14 + .../NVIDIA-Nemotron-Nano-12B-v2.yaml | 113 ++++++++ .../pruning/ffn_pruning.yaml | 12 + .../configs/pruning/attn_pruning.yaml | 9 +- .../ffn_pruning_base.yaml} | 7 +- .../configs/pruning/hidden_dim_pruning.yaml | 2 +- .../pruning/pruning_defaults.yaml | 5 +- .../validate_model_defaults.yaml | 0 .../configs/validate_solutions_defaults.yaml | 0 .../llama_3_1_8b_instruct/config.json | 38 --- tests/gpu/torch/puzzletron/test_puzzletron.py | 142 +++++----- tox.ini | 2 + 77 files changed, 2592 insertions(+), 696 deletions(-) create mode 100644 modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen3_8b/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py delete mode 100644 tests/_test_utils/torch/puzzletron/resources/configs/pruning/ffn_pruning.yaml delete mode 100644 tests/_test_utils/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml delete mode 100644 tests/_test_utils/torch/puzzletron/resources/configs/validate_model_defaults.yaml delete mode 100644 tests/_test_utils/torch/puzzletron/resources/tokenizer/special_tokens_map.json delete mode 100644 tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer.json delete mode 100644 tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer_config.json delete mode 100644 tests/_test_utils/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py rename tests/{_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-ffn-pruning.yaml => gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml} (76%) create mode 100644 tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml rename tests/{_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-attn-pruning.yaml => gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml} (76%) create mode 100644 tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml delete mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml delete mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml delete mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml rename tests/gpu/torch/puzzletron/resources/configs/{llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml => meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml} (94%) create mode 100644 tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml rename tests/gpu/torch/puzzletron/resources/configs/{llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml => meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml} (94%) create mode 100644 tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml rename tests/{_test_utils => gpu}/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml (67%) rename tests/gpu/torch/puzzletron/resources/configs/{llama_3_1_8b_instruct/pruning/ffn_pruning.yaml => pruning/ffn_pruning_base.yaml} (72%) rename tests/{_test_utils => gpu}/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml (93%) rename tests/gpu/torch/puzzletron/resources/configs/{llama_3_1_8b_instruct => }/pruning/pruning_defaults.yaml (94%) rename tests/gpu/torch/puzzletron/resources/configs/{llama_3_1_8b_instruct => }/validate_model_defaults.yaml (100%) rename tests/{_test_utils => gpu}/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml (100%) delete mode 100644 tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index 7cd7214443..a868fddc13 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -602,9 +602,9 @@ def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): assert self.optimize_for in ["latency", "memory"] self.hidden_size = model_config.hidden_size - self.n_heads_in_group = block_config.attention.n_heads_in_group self.num_q_heads = model_config.num_attention_heads - self.num_kv_heads = self.num_q_heads // self.n_heads_in_group + self.num_kv_heads = block_config.attention.num_key_value_heads + self.n_heads_in_group = self.num_q_heads // self.num_kv_heads self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) self.agg_kv_head_contributions = torch.zeros( @@ -1142,61 +1142,39 @@ def __call__( class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): - """Expert removal importance hook for Qwen3-VL models. - - TODO: Implement get_router_logits_and_routed_experts based on Qwen3-VL MoE forward pass. - """ + """Expert removal importance hook for Qwen3-VL models.""" def get_router_logits_and_routed_experts( self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """Extract router logits and expert outputs for Qwen3-VL MoE. - Note: This is a placeholder implementation. Implement based on Qwen3VLMoeSparseMoe forward. + Based on Qwen3VLMoeSparseMoe forward pass. """ - batch_size = ( - hidden_states.shape[0] * hidden_states.shape[1] - if hidden_states.ndim > 2 - else hidden_states.shape[0] - ) - router_logits_out = torch.zeros( - batch_size, self.num_local_experts, device=hidden_states.device - ) - routed_experts = hidden_states.view(-1, hidden_states.shape[-1]) - return router_logits_out, routed_experts + orig_shape = hidden_states.shape + # Flatten to (num_tokens, hidden_size) for processing + hidden_states_flat = hidden_states.reshape(-1, self.moe.hidden_size) -class GptOssRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): - """Expert removal importance hook for GPT-OSS models. + if router_logits is None: + router_logits = self.moe.gate(hidden_states_flat) + + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.moe.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states_flat.dtype) + router_weights = torch.zeros_like(router_logits).scatter_( + 1, router_indices, routing_weights + ) - TODO: Implement get_router_logits_and_routed_experts based on GPT-OSS MoE forward pass. - This is a placeholder implementation that allows the framework to run. - """ + # Reshape hidden_states for moe.experts (expects 3D: batch, seq, hidden) + # router_weights and router_indices remain 2D (num_tokens, num_experts) + batch_size = orig_shape[0] if hidden_states.ndim == 3 else 1 + hidden_states_3d = hidden_states_flat.reshape(batch_size, -1, self.moe.hidden_size) - def get_router_logits_and_routed_experts( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor]: - """Extract router logits and expert outputs for GPT-OSS MoE. + routed_out = self.moe.experts(hidden_states_3d, router_weights, router_indices) - Note: This is a placeholder implementation. For proper expert scoring, - implement based on GptOssSparseMoeBlock forward pass. + # Return in same shape as input + routed_out = routed_out.reshape(*orig_shape) - Args: - hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) - router_logits: Optional pre-computed router logits - - Returns: - tuple of (router_logits, routed_experts): - - router_logits: Shape (num_tokens, num_local_experts) - zeros as placeholder - - routed_experts: Original hidden states (no-op) - """ - batch_size = ( - hidden_states.shape[0] * hidden_states.shape[1] - if hidden_states.ndim > 2 - else hidden_states.shape[0] - ) - router_logits_out = torch.zeros( - batch_size, self.num_local_experts, device=hidden_states.device - ) - routed_experts = hidden_states.view(-1, hidden_states.shape[-1]) - return router_logits_out, routed_experts + return router_logits, routed_out diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py index 1b1485c713..33243c0125 100644 --- a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -19,8 +19,11 @@ from typing import Type +import torch + from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook as ActivationsHook from modelopt.torch.puzzletron.tools.logger import aprint +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock, DummyModule def register_activation_hooks( @@ -51,6 +54,16 @@ def register_activation_hooks( module_names_to_hook = pruning_mixin.get_module_names_to_hook(model) activation_hooks = dict() for block_idx, module_name in module_names_to_hook: + try: + module = model.get_submodule(module_name) + except AttributeError: + # Module doesn't exist on this rank's shard (e.g., in distributed setup) + continue + + # Skip dummy modules - they don't have real activations to hook + if isinstance(module, (DummyModule, DummyBlock)): + continue + block_config = None if block_idx is not None: block_config = model.config.block_configs[block_idx] @@ -59,13 +72,25 @@ def register_activation_hooks( "block_config": block_config, } - module = model.get_submodule(module_name) hook = hook_class(module, curr_activation_hooks_kwargs) module.register_forward_hook(hook) activation_hooks[module_name] = hook if len(activation_hooks) == 0: - raise ValueError("couldn't find any hooks") + # In distributed mode, it's okay for a rank to have 0 hooks if it doesn't own + # the target modules (e.g., with hybrid patterns like "*-" where different + # ranks own different layer types). However, we still want to catch real bugs + # where no hooks are found at all. + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + if is_distributed: + aprint( + "No hooks registered on this rank. This is expected if this rank " + "doesn't own any layers matching the hook pattern (e.g., in hybrid " + "patterns with distributed model sharding)." + ) + else: + raise ValueError("couldn't find any hooks") - aprint(f"Found the following hooks: {activation_hooks.keys()}") + if len(activation_hooks) > 0: + aprint(f"Found the following hooks: {activation_hooks.keys()}") return activation_hooks diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter.py b/modelopt/torch/puzzletron/anymodel/converter/converter.py index 5fdc92718c..eb2330b515 100644 --- a/modelopt/torch/puzzletron/anymodel/converter/converter.py +++ b/modelopt/torch/puzzletron/anymodel/converter/converter.py @@ -135,9 +135,10 @@ def convert_configs_in_dirs( cls, input_dir: Path, output_dir: Path, + trust_remote_code: bool = False, ): """Convert config and add block_configs.""" - config = load_model_config(input_dir) + config = load_model_config(input_dir, trust_remote_code=trust_remote_code) block_configs = cls.create_block_configs_from_main_config(config) out_config = copy.deepcopy(config) @@ -179,7 +180,10 @@ def convert( output_dir: Path to the output AnyModel checkpoint. """ cls.copy_checkpoint_files(input_dir, output_dir) - config = cls.convert_configs_in_dirs(input_dir, output_dir) + trust_remote_code = descriptor.requires_trust_remote_code() + config = cls.convert_configs_in_dirs( + input_dir, output_dir, trust_remote_code=trust_remote_code + ) cls.convert_model_weights( input_dir, output_dir, descriptor=descriptor, num_hidden_layers=config.num_hidden_layers ) diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py index 73d56d2016..4cc4356c8e 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py @@ -53,6 +53,18 @@ def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any] """ raise NotImplementedError + @staticmethod + def requires_trust_remote_code() -> bool: + """Whether this model descriptor requires trust_remote_code=True for loading. + + Models that use custom code (e.g., via auto_map in config) should override + this to return True. + + Returns: + True if trust_remote_code=True is required, False otherwise. + """ + return False + @staticmethod def mlp_no_op_post_init(decoder_layer: nn.Module): """Post-init callback to alter a decoder layer so that FFN/mlp subblock performs as no-op. diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py index f2119059f4..1f3fb477be 100644 --- a/modelopt/torch/puzzletron/anymodel/models/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -16,9 +16,9 @@ # Import models to trigger factory registration # from modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b import * from modelopt.torch.puzzletron.anymodel.models.llama import * -# from modelopt.torch.puzzletron.anymodel.models.mistral_small import * -# from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * -# from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * -# from modelopt.torch.puzzletron.anymodel.models.qwen2 import * -# from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import * -# from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import * +from modelopt.torch.puzzletron.anymodel.models.mistral_small import * +from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * +from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * +from modelopt.torch.puzzletron.anymodel.models.qwen2 import * +from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import * +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import * diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py index fe416e2dd6..082e5da599 100644 --- a/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py @@ -39,6 +39,7 @@ from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( FFNIntermediateLayerDescriptor, ) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor @ModelDescriptorFactory.register_decorator("llama") @@ -129,3 +130,12 @@ class LlamaFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): linear_weight_names: List[str] = field( default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] ) + + +@dataclass +class LlamaKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py new file mode 100644 index 0000000000..821be47e9d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_converter import ( + MistralSmallConverter, +) +from modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor import ( + MistralSmallModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py new file mode 100644 index 0000000000..ddc8151dc9 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from typing import List + +from transformers import MistralConfig + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) + + +@ConverterFactory.register_decorator("mistral_small") +class MistralSmallConverter(Converter): + @staticmethod + def create_block_configs_from_main_config(config: MistralConfig) -> List[BlockConfig]: + num_hidden_layers = config.num_hidden_layers + + block_config = BlockConfig( + attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + + block_configs = [block_config] * num_hidden_layers + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py new file mode 100644 index 0000000000..1ac2bd7072 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers.models.mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralForCausalLM, + MistralRotaryEmbedding, +) + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediateLayerDescriptor, +) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor + + +@ModelDescriptorFactory.register_decorator("mistral_small") +class MistralSmallModelDescriptor(ModelDescriptor): + @staticmethod + def decoder_layer_cls(): + return MistralDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: MistralDecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: MistralDecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: MistralForCausalLM, runtime): + model.model.rotary_emb = MistralRotaryEmbedding(model.config, runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class MistralFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class MistralKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py new file mode 100644 index 0000000000..a2140f118e --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_converter import ( + NemotronHConverter, +) +from modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor import ( + NemotronHModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py new file mode 100644 index 0000000000..16d9e3c73d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + MambaConfig, + MoEConfig, +) + + +@ConverterFactory.register_decorator("nemotron_h") +class NemotronHConverter(Converter): + @staticmethod + def create_block_configs_from_main_config(config) -> List[BlockConfig]: + # Create block configs for each layer based on the hybrid_override_pattern + block_configs = [] + + # Parse the hybrid_override_pattern: "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-" + pattern = config.hybrid_override_pattern + print(f"Parsing hybrid pattern: {pattern}") + + for i, char in enumerate(pattern): + if char == "M": + _block_config = BlockConfig( + attention=AttentionConfig( + mamba=MambaConfig( # Those parameters are currently used only for calc_block_stats. + state_dim=config.ssm_state_size, + num_heads=config.mamba_num_heads, + head_dim=config.mamba_head_dim, + num_groups=config.n_groups, + ) + ), + ffn=FFNConfig(no_op=True), + ) + + elif char == "-": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig(intermediate_size=config.intermediate_size), + ) + + elif char == "*": + _block_config = BlockConfig( + attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=True), + ) + + elif char == "E": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig( + moe=MoEConfig( + num_local_experts=config.n_routed_experts, + expert_intermediate_dim=config.moe_intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + ) + ), + ) + else: + raise ValueError( + f"Unknown character '{char}' in hybrid_override_pattern at position {i}" + ) + + block_configs.append(_block_config) + + print(f"Created {len(block_configs)} block configs from pattern") + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py new file mode 100644 index 0000000000..55d9ef56ca --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -0,0 +1,256 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import importlib +import inspect +import pkgutil +import re +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Tuple, Type + +import torch.nn as nn + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import MatchingZeros, Same +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn + + +def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: + import transformers_modules + + matches = [] + for finder, modname, ispkg in pkgutil.walk_packages( + transformers_modules.__path__, transformers_modules.__name__ + "." + ): + module = importlib.import_module(modname) + for _, obj in inspect.getmembers(module, inspect.isclass): + if obj.__name__ == module_cls_str: + matches.append(obj) + + return matches + + +@dataclass +class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): + target_name: str = "mixer.gate" + moe_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + expert_prefix_name: str = "experts.{expert_idx}" + router_weights: List[str] = field(default_factory=lambda: ["gate.weight"]) + router_biases: List[str] = field(default_factory=lambda: ["gate.e_score_correction_bias"]) + expert_weights: List[str] = field( + default_factory=lambda: ["up_proj.weight", "down_proj.weight"] + ) + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + if self.target_name != "mixer": + return super().get_modules_names_to_hook(model) + + # when target is `mixer` we'll target moe layers of class type: `NemotronHMOE`, as NemotronH models use auto-map we'll check for class name instead of class type. + target_class_name = "NemotronHMOE" + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + # restrict to attributes called "mixer" and with the desired class name + if ( + module_name.endswith(self.target_name) + and module.__class__.__name__ == target_class_name + ): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook + + +@ModelDescriptorFactory.register_decorator("nemotron_h") +class NemotronHModelDescriptor(ModelDescriptor): + _DECODER_LAYER_CLS: Type[nn.Module] = None + + @staticmethod + def decoder_layer_cls(): + decoder_cls_list = get_dynamic_modules("NemotronHBlock") + if not decoder_cls_list: + raise AssertionError( + "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`" + ) + return decoder_cls_list + + @staticmethod + def requires_trust_remote_code() -> bool: + return True + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + override_kwargs = {} + if block_config.ffn.intermediate_size is not None: + override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size + + if block_config.attention.num_key_value_heads is not None: + override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads + + if block_config.ffn.moe is not None: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts + + return override_kwargs + + @staticmethod + def _block_no_op_post_init(decoder_layer): + """ + Due to the subblock structure of NemotronH always one of the subblock is set to no-op, for a real no-op both attention & ffn no-op should be set to True. + """ + block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx] + if block_config.ffn.no_op and block_config.attention.no_op: + decoder_layer.norm = Same() + decoder_layer.mixer = MatchingZeros() + + @staticmethod + def attn_no_op_post_init(decoder_layer): + NemotronHModelDescriptor._block_no_op_post_init(decoder_layer) + + @staticmethod + def mlp_no_op_post_init(decoder_layer): + NemotronHModelDescriptor._block_no_op_post_init(decoder_layer) + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + dummy_block = super().create_dummy_block(original_layer, block_index) + # Required by `NemotronHModel.forward`. + dummy_block.block_type = original_layer.block_type + # Preserve layer_idx if it exists (used by _block_no_op_post_init) + if hasattr(original_layer, "layer_idx"): + dummy_block.layer_idx = original_layer.layer_idx + # Preserve config if it exists (used by _block_no_op_post_init to access block_configs) + if hasattr(original_layer, "config"): + dummy_block.config = original_layer.config + return dummy_block + + @staticmethod + def init_rotary_embedding(model, runtime): + """ + NemotronH has no positional embeddings + """ + pass + + @staticmethod + def input_embedding_name(): + return "backbone.embeddings" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "backbone.norm_f" + + @staticmethod + def layer_block_name(index: int): + return f"backbone.layers.{index}" + + @classmethod + def get_weight_groups( + cls, layer_names: Iterable[str], num_hidden_layers: int + ) -> Dict[str, List[str]]: + """ + Problem with NemotronH is that `norm.weight` can be in both block_{i}_ffn and block_{i}_attention. duplicate groups with `norm.weight` should be removed. + """ + weight_groups = defaultdict(list) + for name in layer_names: + is_matched = False + for group, pattern in cls.layer_name_predicates(num_hidden_layers).items(): + if pattern.match(name): + weight_groups[group].append(name) + is_matched = True + if not is_matched: + raise ValueError(f"Couldn't find a match for {name}") + + valid_weight_groups = {} + for group, names in weight_groups.items(): + if len(names) == 1: + only_name = names[0] + if only_name.endswith("norm.weight") and "layers" in only_name: + # Skip and don't append this group to valid_weight_groups + continue + valid_weight_groups[group] = names + + return valid_weight_groups + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile( + r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$" + ), + "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN FFN + r"mixer\.(gate\.e_score_correction_bias" + r"|gate\.weight" + r"|experts\.\d+\.up_proj\.weight" + r"|experts\.\d+\.down_proj\.weight" + r"|shared_experts\.up_proj\.weight" + r"|shared_experts\.down_proj\.weight))$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN ATTENTION + r"mixer\.(norm\.weight" + r"|A_log" + r"|D" + r"|conv1d\.weight" + r"|conv1d\.bias" + r"|dt_bias" + r"|in_proj\.weight" + r"|out_proj\.weight" + r"|q_proj\.weight" + r"|k_proj\.weight" + r"|v_proj\.weight" + r"|o_proj\.weight))$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update( + **build_ffn_predicates(), + **build_attention_predicates(), + ) + + return layer_name_patterns + + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + } diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py new file mode 100644 index 0000000000..4b17785ace --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_converter import ( + NemotronHV2Converter, +) +from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor import ( + NemotronHV2ModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py new file mode 100644 index 0000000000..2c54388325 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + MambaConfig, + MoEConfig, +) + + +@ConverterFactory.register_decorator("nemotron_h_v2") +class NemotronHV2Converter(Converter): + @staticmethod + def create_block_configs_from_main_config(config) -> List[BlockConfig]: + # Create block configs for each layer based on the hybrid_override_pattern + block_configs = [] + + # Parse the hybrid_override_pattern: "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-" + pattern = config.hybrid_override_pattern + print(f"Parsing hybrid pattern: {pattern}") + + for i, char in enumerate(pattern): + if char == "M": + _block_config = BlockConfig( + attention=AttentionConfig( + mamba=MambaConfig( # Those parameters are currently used only for calc_block_stats. + state_dim=config.ssm_state_size, + num_heads=config.mamba_num_heads, + head_dim=config.mamba_head_dim, + num_groups=config.n_groups, + ) + ), + ffn=FFNConfig(no_op=True), + ) + + elif char == "-": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig(intermediate_size=config.intermediate_size), + ) + + elif char == "*": + _block_config = BlockConfig( + attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=True), + ) + + elif char == "E": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig( + moe=MoEConfig( + num_local_experts=config.n_routed_experts, + expert_intermediate_dim=config.moe_intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + ) + ), + ) + else: + raise ValueError( + f"Unknown character '{char}' in hybrid_override_pattern at position {i}" + ) + + block_configs.append(_block_config) + + print(f"Created {len(block_configs)} block configs from pattern") + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py new file mode 100644 index 0000000000..f50217d4d3 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import pkgutil +import re +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Type + +import torch.nn as nn + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import MatchingZeros, Same +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediateLayerDescriptor, + FFNIntermediatePruningMixIn, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn + + +def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: + import transformers_modules + + matches = [] + for finder, modname, ispkg in pkgutil.walk_packages( + transformers_modules.__path__, transformers_modules.__name__ + "." + ): + module = importlib.import_module(modname) + for _, obj in inspect.getmembers(module, inspect.isclass): + if obj.__name__ == module_cls_str: + matches.append(obj) + + return matches + + +@dataclass +class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mixer.down_proj" + ffn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"]) + + +@ModelDescriptorFactory.register_decorator("nemotron_h_v2") +class NemotronHV2ModelDescriptor(ModelDescriptor): + _DECODER_LAYER_CLS: Type[nn.Module] = None + + @staticmethod + def decoder_layer_cls(): + decoder_cls_list = get_dynamic_modules("NemotronHBlock") + if not decoder_cls_list: + raise AssertionError( + "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`" + ) + return decoder_cls_list + + @staticmethod + def requires_trust_remote_code() -> bool: + return True + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + override_kwargs = {} + if block_config.ffn is not None and block_config.ffn.intermediate_size is not None: + override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size + + if ( + block_config.attention is not None + and block_config.attention.num_key_value_heads is not None + ): + override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads + + if block_config.ffn is not None and block_config.ffn.moe is not None: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts + + return override_kwargs + + @staticmethod + def _block_no_op_post_init(decoder_layer): + """ + Due to the subblock structure of NemotronH always one of the subblock is set to no-op, for a real no-op both attention & ffn no-op should be set to True. + """ + block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx] + ffn_no_op = block_config.ffn is not None and block_config.ffn.no_op + attn_no_op = block_config.attention is not None and block_config.attention.no_op + if ffn_no_op and attn_no_op: + decoder_layer.norm = Same() + decoder_layer.mixer = MatchingZeros() + + @staticmethod + def attn_no_op_post_init(decoder_layer): + NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer) + + @staticmethod + def mlp_no_op_post_init(decoder_layer): + NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer) + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + dummy_block = super().create_dummy_block(original_layer, block_index) + # Required by `NemotronHModel.forward`. + dummy_block.block_type = original_layer.block_type + # Preserve layer_idx if it exists (used by _block_no_op_post_init) + if hasattr(original_layer, "layer_idx"): + dummy_block.layer_idx = original_layer.layer_idx + # Preserve config if it exists (used by _block_no_op_post_init to access block_configs) + if hasattr(original_layer, "config"): + dummy_block.config = original_layer.config + return dummy_block + + @staticmethod + def init_rotary_embedding(model, runtime): + """ + NemotronH has no positional embeddings + """ + pass + + @staticmethod + def input_embedding_name(): + return "backbone.embeddings" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "backbone.norm_f" + + @staticmethod + def layer_block_name(index: int): + return f"backbone.layers.{index}" + + @classmethod + def get_weight_groups( + cls, layer_names: Iterable[str], num_hidden_layers: int + ) -> Dict[str, List[str]]: + """ + Problem with NemotronH is that `norm.weight` can be in both block_{i}_ffn and block_{i}_attention. duplicate groups with `norm.weight` should be removed. + """ + weight_groups = defaultdict(list) + for name in layer_names: + is_matched = False + for group, pattern in cls.layer_name_predicates(num_hidden_layers).items(): + if pattern.match(name): + weight_groups[group].append(name) + is_matched = True + if not is_matched: + raise ValueError(f"Couldn't find a match for {name}") + + valid_weight_groups = {} + for group, names in weight_groups.items(): + if len(names) == 1: + only_name = names[0] + if only_name.endswith("norm.weight") and "layers" in only_name: + # Skip and don't append this group to valid_weight_groups + continue + valid_weight_groups[group] = names + + return valid_weight_groups + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile( + r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$" + ), + "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN FFN + r"mixer\.(gate\.e_score_correction_bias" + r"|gate\.weight" + r"|experts\.\d+\.up_proj\.weight" + r"|experts\.\d+\.down_proj\.weight" + r"|shared_experts\.up_proj\.weight" + r"|shared_experts\.down_proj\.weight" + r"|up_proj\.weight" # Simple MLP (non-MoE) + r"|down_proj\.weight))$" # Simple MLP (non-MoE) + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN ATTENTION + r"mixer\.(norm\.weight" + r"|A_log" + r"|D" + r"|conv1d\.weight" + r"|conv1d\.bias" + r"|dt_bias" + r"|in_proj\.weight" + r"|out_proj\.weight" + r"|q_proj\.weight" + r"|k_proj\.weight" + r"|v_proj\.weight" + r"|o_proj\.weight))$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update( + **build_ffn_predicates(), + **build_attention_predicates(), + ) + + return layer_name_patterns + + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "ffn_intermediate": FFNIntermediatePruningMixIn( + NemotronHV2FFNIntermediateLayerDescriptor() + ), + # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated + } diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py new file mode 100644 index 0000000000..c193fc0d6d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_converter import Qwen2Converter +from modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor import ( + Qwen2ModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py new file mode 100644 index 0000000000..878cfd64dc --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Qwen2 converter for AnyModel compression.""" + +from typing import List + +from transformers import Qwen2Config + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) + + +@ConverterFactory.register_decorator("qwen2") +class Qwen2Converter(Converter): + """Converter for Qwen2 models to AnyModel format.""" + + @staticmethod + def create_block_configs_from_main_config(config: Qwen2Config) -> List[BlockConfig]: + """Create uniform block configs for all Qwen2 layers. + + Qwen2 models have uniform architecture across all layers, so we create + the same BlockConfig for each layer. + """ + num_hidden_layers = config.num_hidden_layers + + block_config = BlockConfig( + attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + + block_configs = [block_config] * num_hidden_layers + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py new file mode 100644 index 0000000000..69185d1de3 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Qwen2 model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass +from typing import Dict + +from torch import nn +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2DecoderLayer, + Qwen2ForCausalLM, + Qwen2RotaryEmbedding, +) + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import ( + LlamaFFNIntermediateLayerDescriptor, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock + + +@ModelDescriptorFactory.register_decorator("qwen2") +class Qwen2ModelDescriptor(ModelDescriptor): + """Model descriptor for Qwen2 models.""" + + @staticmethod + def decoder_layer_cls(): + return Qwen2DecoderLayer + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + """Create a dummy block that preserves Qwen2-specific attributes like attention_type. + + Qwen2's forward pass accesses decoder_layer.attention_type for attention mask selection. + """ + dummy = DummyBlock(block_index=block_index) + # Copy attention_type from original layer (required by Qwen2's forward pass) + if hasattr(original_layer, "attention_type"): + dummy.attention_type = original_layer.attention_type + return dummy + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: Qwen2DecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: Qwen2DecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: Qwen2ForCausalLM, runtime): + model.model.rotary_emb = Qwen2RotaryEmbedding(config=model.config, device=runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + # Qwen2 has biases on attention projections + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.q_proj\.bias" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.k_proj\.bias" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.v_proj\.bias" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class Qwen2FFNIntermediateLayerDescriptor(LlamaFFNIntermediateLayerDescriptor): + """Layer descriptor for Qwen2 FFN intermediate pruning. + + Qwen2 uses the same FFN structure as Llama (gate_proj, up_proj, down_proj). + """ + + pass diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/__init__.py new file mode 100644 index 0000000000..0f753f705d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_converter import Qwen3_8BConverter +from modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor import ( + Qwen3_8BModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_converter.py new file mode 100644 index 0000000000..1a389291df --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_converter.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +from typing import List + +from transformers import Qwen3Config + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) + + +@ConverterFactory.register_decorator("qwen3") +class Qwen3_8BConverter(Converter): + @staticmethod + def create_block_configs_from_main_config(config: Qwen3Config) -> List[BlockConfig]: + num_hidden_layers = config.num_hidden_layers + + block_config = BlockConfig( + attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + + block_configs = [block_config] * num_hidden_layers + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_model_descriptor.py new file mode 100644 index 0000000000..679ee73fae --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_model_descriptor.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from torch import nn +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3DecoderLayer, + Qwen3ForCausalLM, + Qwen3RotaryEmbedding, +) + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediateLayerDescriptor, +) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock + + +@ModelDescriptorFactory.register_decorator("qwen3") +class Qwen3_8BModelDescriptor(ModelDescriptor): + @staticmethod + def decoder_layer_cls(): + return Qwen3DecoderLayer + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + """Create a dummy block that preserves Qwen3-specific attributes like attention_type. + + Qwen3's forward pass accesses decoder_layer.attention_type for attention mask selection. + """ + dummy = DummyBlock(block_index=block_index) + # Copy attention_type from original layer (required by Qwen3's forward pass) + if hasattr(original_layer, "attention_type"): + dummy.attention_type = original_layer.attention_type + return dummy + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: Qwen3DecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: Qwen3DecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: Qwen3ForCausalLM, runtime): + model.model.rotary_emb = Qwen3RotaryEmbedding(model.config, runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight" + r"|self_attn\.q_norm\.weight" + r"|self_attn\.k_norm\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class Qwen3_8BFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class Qwen3_8BKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/__init__.py new file mode 100644 index 0000000000..7bf317d29e --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_converter import ( + Qwen3VL30BA3BInstructConverter, +) +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor import ( + Qwen3VL30BA3BInstructModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py new file mode 100644 index 0000000000..0c50dfeb9e --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +from typing import List + +from transformers import Qwen3VLMoeConfig + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + MoEConfig, +) + + +@ConverterFactory.register_decorator("qwen3_vl") +class Qwen3VL30BA3BInstructConverter(Converter): + @staticmethod + def create_block_configs_from_main_config(config: Qwen3VLMoeConfig) -> List[BlockConfig]: + # Qwen3-VL MoE has nested text_config + text_config = config.text_config if hasattr(config, "text_config") else config + + num_hidden_layers = text_config.num_hidden_layers + decoder_sparse_step = getattr(text_config, "decoder_sparse_step", 1) + mlp_only_layers = getattr(text_config, "mlp_only_layers", []) + + block_configs = [] + for layer_idx in range(num_hidden_layers): + # Check if this layer is MoE or dense + is_moe_layer = (layer_idx % decoder_sparse_step == 0) and ( + layer_idx not in mlp_only_layers + ) + + if is_moe_layer: + # MoE layer + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=text_config.num_key_value_heads + ), + ffn=FFNConfig( + moe=MoEConfig( + num_local_experts=text_config.num_experts, + expert_intermediate_dim=text_config.moe_intermediate_size, + num_experts_per_tok=text_config.num_experts_per_tok, + ) + ), + ) + else: + # Dense layer + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=text_config.num_key_value_heads + ), + ffn=FFNConfig(no_op=False, intermediate_size=text_config.intermediate_size), + ) + + block_configs.append(block_config) + + print( + f"Created {len(block_configs)} block configs for Qwen3-VL MoE (decoder_sparse_step={decoder_sparse_step})" + ) + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py new file mode 100644 index 0000000000..7c7665a644 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py @@ -0,0 +1,212 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +import torch.nn as nn +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextRotaryEmbedding, + Qwen3VLMoeVisionRotaryEmbedding, +) + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, +) +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediateLayerDescriptor, +) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor + + +@ModelDescriptorFactory.register_decorator("qwen3_vl") +class Qwen3VL30BA3BInstructModelDescriptor(ModelDescriptor): + @staticmethod + def uses_autocast() -> bool: + """ + Qwen3-VL MoE has a dtype bug in HuggingFace transformers under torch.autocast: + scatter() in MoE routing fails with dtype mismatch. Use native bfloat16 instead. + See: https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct (recommended approach) + """ + return False + + @staticmethod + def get_language_model_config(config): + """Qwen3-VL has nested text_config for language model parameters.""" + return config.text_config if hasattr(config, "text_config") else config + + @staticmethod + def decoder_layer_cls(): + return Qwen3VLMoeTextDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + override_kwargs = {"num_key_value_heads": block_config.attention.num_key_value_heads} + + if block_config.ffn.moe: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["num_experts"] = block_config.ffn.moe.num_local_experts + else: + override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size + + return override_kwargs + + @staticmethod + def attn_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model, runtime): + # Re-initialize text rotary embedding on correct device and dtype + text_config = Qwen3VL30BA3BInstructModelDescriptor.get_language_model_config(model.config) + model.model.language_model.rotary_emb = Qwen3VLMoeTextRotaryEmbedding( + config=text_config + ).to(device=runtime.device, dtype=runtime.dtype) + # Re-initialize vision rotary embedding on correct device and dtype + vision_config = ( + model.config.vision_config if hasattr(model.config, "vision_config") else None + ) + if vision_config is not None: + head_dim = vision_config.hidden_size // vision_config.num_heads + model.model.visual.rotary_pos_emb = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2).to( + device=runtime.device, dtype=runtime.dtype + ) + + @staticmethod + def input_embedding_name(): + return "model.language_model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.language_model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.language_model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + # Qwen3-VL has text model under model.language_model.* prefix + layer_name_patterns = { + "embeddings": re.compile(r"^model\.language_model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.language_model\.norm\.weight|lm_head\.weight)$"), + # Vision encoder (includes merger under model.visual.deepstack_merger_list.*) + "vision_encoding": re.compile(r"^model\.visual\..*"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.language_model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + # MoE router + r"|mlp\.gate\.weight" + # MoE experts - fused format (gate_up_proj, down_proj without .weight suffix) + r"|mlp\.experts\.gate_up_proj" + r"|mlp\.experts\.down_proj" + # Shared expert (if present) + r"|mlp\.shared_expert\.up_proj\.weight" + r"|mlp\.shared_expert\.gate_proj\.weight" + r"|mlp\.shared_expert\.down_proj\.weight" + r"|mlp\.shared_expert_gate\.weight" + # Dense MLP fallback (for non-MoE layers) + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.language_model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight" + r"|self_attn\.q_norm\.weight" + r"|self_attn\.k_norm\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class Qwen3VL30BA3BInstructFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class Qwen3VL30BA3BInstructKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.language_model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + +@dataclass +class Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): + """ + Qwen3-VL MoE layer descriptor. + + Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py + - Qwen3VLMoeTextSparseMoeBlock: MoE block with .gate (router) and .experts + - Qwen3VLMoeTextTopKRouter: Router with .weight (no bias) + - Qwen3VLMoeTextExperts: Fused experts with .gate_up_proj and .down_proj tensors + """ + + target_name: str = "mlp" + moe_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp" + # Router: Qwen3VLMoeTextTopKRouter has self.weight, no bias + router_weights: List[str] = field(default_factory=lambda: ["gate.weight"]) + router_biases: List[str] = field(default_factory=lambda: []) + # Fused expert format: Qwen3VLMoeTextExperts stores all experts in single tensors + # with shape [num_experts, ...] instead of separate tensors per expert. + is_fused_experts: bool = True + fused_expert_weights: List[str] = field( + default_factory=lambda: ["experts.gate_up_proj", "experts.down_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py index aac57af0a9..9b3a9a2190 100644 --- a/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py @@ -43,7 +43,7 @@ class Wrapped(cls): def forward(self, *args, **kwargs): result = super().forward(*args, **kwargs) outputs = [None] * size - outputs[0] = result[0] + outputs[0] = result if isinstance(result, torch.Tensor) else result[0] return tuple(outputs) def extra_repr(self) -> str: diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py index 22d00ea773..24be1b227d 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py @@ -534,7 +534,7 @@ def __init__( self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[ffn_config.hidden_act] + self.act_fn = ACT2FN[getattr(ffn_config, "hidden_act", "silu")] if ffn_config.sparsify is not None: self.register_full_backward_hook(sparsity_backward_hook) @@ -579,7 +579,7 @@ def __init__( self.intermediate_size = ffn_config.intermediate_size self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[ffn_config.hidden_act] + self.act_fn = ACT2FN[getattr(ffn_config, "hidden_act", "silu")] if ffn_config.sparsify is not None: self.register_full_backward_hook(sparsity_backward_hook) @@ -1037,7 +1037,7 @@ def __init__(self, config: DeciLMConfig, layer_idx: int | tuple[int, ...]): self.self_attn = DeciLMLlama4TextAttention(config, layer_idx, self.attention_config) if not (self.ffn_config.no_op or self.attention_config.is_mamba): - if self.ffn_config.hidden_act is None: + if getattr(self.ffn_config, "hidden_act", None) is None: print(f"WARNING: FFN hidden_act is None for layer {layer_idx}") self.post_attention_layernorm = DeciLMRMSNorm( diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index da0f90452d..71913db7d3 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -29,6 +29,10 @@ import yaml from omegaconf import DictConfig, ListConfig, OmegaConf +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, @@ -558,7 +562,12 @@ def _parse_teacher_block_metrics( ) -> list[dict]: raw_metrics = json.loads((single_block_replacement_validation_dir / "teacher.json").read_text()) teacher_checkpoint_dir = Path(raw_metrics["args"]["teacher_dir"]).resolve() - teacher_model_config = load_model_config(teacher_checkpoint_dir) + descriptor_name = raw_metrics["args"]["descriptor"] + descriptor = ModelDescriptorFactory.get(descriptor_name) + trust_remote_code = descriptor.requires_trust_remote_code() + teacher_model_config = load_model_config( + teacher_checkpoint_dir, trust_remote_code=trust_remote_code + ) teacher_replacements = None replacement_library_path = raw_metrics["args"].get("replacement_library_path") diff --git a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py index 96d3489f5e..3c00ca212a 100644 --- a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py @@ -21,7 +21,6 @@ from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( ForwardHook, - GptOssRemoveExpertsIndependentHook, NemotronHRemoveExpertsIndependentHook, Qwen3VLRemoveExpertsIndependentHook, RankedChoiceVotingHook, @@ -82,7 +81,6 @@ def supported_hooks(self) -> List[Type[ForwardHook]]: RankedChoiceVotingHookNemotronH, NemotronHRemoveExpertsIndependentHook, Qwen3VLRemoveExpertsIndependentHook, - GptOssRemoveExpertsIndependentHook, ] def prune_single_layer( diff --git a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py index 823f42faf8..b9cfd75faf 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py +++ b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py @@ -95,6 +95,12 @@ def launch_ffn_intermediates_prune_ckpt( def launch_attn_groups_prune_ckpt( cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None ): + descriptor = cfg.descriptor + parent_model_config = load_model_config( + cfg.teacher_dir, trust_remote_code=descriptor.requires_trust_remote_code() + ) + num_attention_heads = parent_model_config.num_attention_heads + for n_heads_in_group in cfg.pruning.n_heads_in_group_list: dirname = f"n_heads_in_group{n_heads_in_group}" @@ -105,7 +111,8 @@ def launch_attn_groups_prune_ckpt( mprint("Process n_heads_in_group {}".format(n_heads_in_group)) mprint(f"=== STARTING ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===") - model_config_overrides_json = {"attention": [{"n_heads_in_group": n_heads_in_group}]} + num_key_value_heads = num_attention_heads // n_heads_in_group + model_config_overrides_json = {"attention": [{"num_key_value_heads": num_key_value_heads}]} mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) @@ -151,7 +158,11 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): ) # Load parent model config to get FFN configuration - parent_model_config = load_model_config(cfg.pruning.model_name_or_path) + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + trust_remote_code = descriptor.requires_trust_remote_code() + parent_model_config = load_model_config( + cfg.pruning.model_name_or_path, trust_remote_code=trust_remote_code + ) parent_hidden_size = parent_model_config.hidden_size # Get teacher's FFN configuration diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index 0f5ecd2158..cc81f4f887 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -39,6 +39,10 @@ import pandas as pd from omegaconf import DictConfig +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, @@ -65,6 +69,7 @@ def build_replacement_library( master_puzzle_dir: Path | str, + descriptor: ModelDescriptor, teacher_checkpoint_dir: Path | str | None = None, add_ffn_no_ops: bool = True, add_attention_no_ops: bool = True, @@ -76,20 +81,22 @@ def build_replacement_library( master_puzzle_dir = Path(master_puzzle_dir) (master_puzzle_dir / "ckpts").mkdir(exist_ok=True) teacher_checkpoint_dir = infer_teacher_dir(master_puzzle_dir, teacher_checkpoint_dir) + trust_remote_code = descriptor.requires_trust_remote_code() subblocks_df = _build_subblocks_df( master_puzzle_dir, teacher_checkpoint_dir, add_ffn_no_ops, add_attention_no_ops, + trust_remote_code=trust_remote_code, ) block_library_df = _build_block_library_from_subblocks(subblocks_df) layer_replacements = _build_layer_replacements( - block_library_df, master_puzzle_dir, teacher_checkpoint_dir + block_library_df, master_puzzle_dir, teacher_checkpoint_dir, trust_remote_code ) single_sequence_replacement_solutions = _build_single_sequence_replacement_solutions( - layer_replacements, teacher_checkpoint_dir + layer_replacements, teacher_checkpoint_dir, trust_remote_code ) json_dump(block_library_df.to_dict(orient="records"), master_puzzle_dir / "block_library.json") @@ -112,11 +119,13 @@ def launch_build_replacement_library(cfg: DictConfig) -> None: f"Build replacement library config: {format_global_config(cfg.build_replacement_library, title='Build replacement library')}" ) + descriptor = ModelDescriptorFactory.get(cfg.descriptor) build_replacement_library( master_puzzle_dir=cfg.puzzle_dir, teacher_checkpoint_dir=cfg.teacher_dir, add_ffn_no_ops=cfg.build_replacement_library.add_ffn_no_ops, add_attention_no_ops=cfg.build_replacement_library.add_attention_no_ops, + descriptor=descriptor, ) @@ -191,9 +200,12 @@ def _build_subblocks_df( teacher_checkpoint_dir: Path | str, add_ffn_no_ops: bool, add_attention_no_ops: bool, + trust_remote_code: bool = False, ) -> pd.DataFrame: teacher_checkpoint_dir = Path(teacher_checkpoint_dir) - checkpoint_dirs = _get_last_checkpoint_from_each_experiment(master_puzzle_dir) + checkpoint_dirs = _get_last_checkpoint_from_each_experiment( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir}) checkpoints_to_split = [teacher_checkpoint_dir] @@ -203,7 +215,7 @@ def _build_subblocks_df( if len(subblocks_to_extract) > 0: subblock_rows_from_current_checkpoint = ( _construct_subblock_rows_from_current_checkpoint( - checkpoint_dir, subblocks_to_extract + checkpoint_dir, subblocks_to_extract, trust_remote_code=trust_remote_code ) ) subblock_rows.extend(subblock_rows_from_current_checkpoint) @@ -303,10 +315,10 @@ def _drop_duplicates_of_decomp_no_op(subblocks_df: pd.DataFrame) -> pd.DataFrame def _construct_subblock_rows_from_current_checkpoint( - checkpoint_dir: Path, subblocks_to_extract: list[str] + checkpoint_dir: Path, subblocks_to_extract: list[str], trust_remote_code: bool = False ) -> list[dict[str, Any]]: subblock_rows_from_current_checkpoint = [] - model_config = load_model_config(checkpoint_dir) + model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) for block_idx, block_config in enumerate(model_config.block_configs): for subblock_to_extract in subblocks_to_extract: subblock_row = _init_empty_subblock_row(block_idx) @@ -388,7 +400,9 @@ def _get_rows_with_no_op_subblock( return rows_with_no_op_subblock, subblock_cls -def _get_last_checkpoint_from_each_experiment(master_puzzle_dir: Path | str) -> set[Path]: +def _get_last_checkpoint_from_each_experiment( + master_puzzle_dir: Path | str, trust_remote_code: bool = False +) -> set[Path]: master_puzzle_dir = Path(master_puzzle_dir) master_checkpoints_dir = master_puzzle_dir / CHECKPOINTS_DIR_NAME subdirs_of_master_checkpoints_dir = [ @@ -409,7 +423,11 @@ def _get_last_checkpoint_from_each_experiment(master_puzzle_dir: Path | str) -> ) # Filter out non-DeciLM checkpoints (e.g., unconverted Llama checkpoints) - valid_checkpoint_dirs = [cp for cp in checkpoint_dirs if is_valid_decilm_checkpoint(cp)] + valid_checkpoint_dirs = [ + cp + for cp in checkpoint_dirs + if is_valid_decilm_checkpoint(cp, trust_remote_code=trust_remote_code) + ] experiment_dirs = [ p if (p in subdirs_of_master_checkpoints_dir) else p.parent for p in valid_checkpoint_dirs @@ -465,14 +483,15 @@ def _build_layer_replacements( block_library_df: pd.DataFrame, master_puzzle_dir: Path, teacher_checkpoint_dir: Path, + trust_remote_code: bool = False, ) -> list[dict]: layer_replacements_from_blocks = _build_layer_replacements_from_block_library(block_library_df) layer_replacements_from_checkpoints = _gather_layer_replacements_from_checkpoints( - master_puzzle_dir + master_puzzle_dir, trust_remote_code=trust_remote_code ) layer_replacements = layer_replacements_from_blocks + layer_replacements_from_checkpoints layer_replacements = _filter_duplicate_teacher_replacements( - layer_replacements, teacher_checkpoint_dir + layer_replacements, teacher_checkpoint_dir, trust_remote_code ) return layer_replacements @@ -502,9 +521,13 @@ def _build_layer_replacements_from_block_library(block_library_df: pd.DataFrame) return layer_replacements -def _gather_layer_replacements_from_checkpoints(master_puzzle_dir: str | Path) -> list[dict]: +def _gather_layer_replacements_from_checkpoints( + master_puzzle_dir: str | Path, trust_remote_code: bool = False +) -> list[dict]: gathered_layer_replacements = [] - checkpoint_dirs = _get_last_checkpoint_from_each_experiment(master_puzzle_dir) + checkpoint_dirs = _get_last_checkpoint_from_each_experiment( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) for checkpoint_dir in checkpoint_dirs: if (layer_replacements_path := checkpoint_dir / "replacement_library.json").exists(): layer_replacements = json.loads(layer_replacements_path.read_text()) @@ -523,8 +546,11 @@ def _gather_layer_replacements_from_checkpoints(master_puzzle_dir: str | Path) - def _filter_duplicate_teacher_replacements( layer_replacements: list[dict], teacher_checkpoint_dir: Path, + trust_remote_code: bool = False, ) -> list[dict]: - teacher_model_config = load_model_config(teacher_checkpoint_dir) + teacher_model_config = load_model_config( + teacher_checkpoint_dir, trust_remote_code=trust_remote_code + ) filtered_layer_replacements = [] for layer_replacement in layer_replacements: if replacement_is_teacher( @@ -537,8 +563,11 @@ def _filter_duplicate_teacher_replacements( def _build_single_sequence_replacement_solutions( layer_replacements: list[dict], teacher_checkpoint_dir: Path, + trust_remote_code: bool = False, ) -> list[dict]: - teacher_model_config = load_model_config(teacher_checkpoint_dir) + teacher_model_config = load_model_config( + teacher_checkpoint_dir, trust_remote_code=trust_remote_code + ) n_layer = teacher_model_config.num_hidden_layers teacher_replacements = dict() diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index 7935fea4a0..8a7c2834fd 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -123,10 +123,12 @@ def n_layer(self) -> int: @property def model_config(self) -> DeciLMConfig: if self._model_config is None: + trust_remote_code = self.descriptor.requires_trust_remote_code() self._model_config = load_model_config( self.get_arbitrary_checkpoint_dir(), self.model_config_overrides, ignore_unexpected_config_keys=True, + trust_remote_code=trust_remote_code, ) return self._model_config diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index 2db0bc3916..0b8a3e72fe 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -285,7 +285,8 @@ def calculate_subblock_stats_for_puzzle_dir( teacher_dir = ( Path(teacher_dir) if teacher_dir is not None else master_puzzle_dir / "ckpts" / "teacher" ) - model_config = load_model_config(teacher_dir) + trust_remote_code = descriptor.requires_trust_remote_code() + model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code) # Get language model config for LM-specific attributes (VL models have nested config) lm_config = descriptor.get_language_model_config(model_config) subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes, model_config) diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index 36e41c4b6a..ecfb8b857b 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -86,7 +86,9 @@ def init_child_from_parent( copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir) - parent_model_config = load_model_config(parent_checkpoint_dir) + parent_model_config = load_model_config( + parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code() + ) parent_state_dict = load_state_dict(parent_checkpoint_dir) # Parse JSON if string @@ -108,6 +110,7 @@ def init_child_from_parent( parent_checkpoint_dir, model_config_overrides=global_config_overrides, ignore_unexpected_config_keys=True, + trust_remote_code=descriptor.requires_trust_remote_code(), ) # Apply block-level overrides if any @@ -126,7 +129,10 @@ def init_child_from_parent( model_class = _get_model_class_from_config(child_model_config) # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() if model_class is AutoModelForCausalLM: - child_model = model_class.from_config(child_model_config, trust_remote_code=True) + trust_remote_code = descriptor.requires_trust_remote_code() + child_model = model_class.from_config( + child_model_config, trust_remote_code=trust_remote_code + ) else: child_model = model_class._from_config(child_model_config) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py index f08b89e449..20c2fbe2ac 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -135,17 +135,20 @@ def skip_init(module_cls, *args, **kwargs) -> nn.Module: return module -def is_valid_decilm_checkpoint(checkpoint_dir: Path | str) -> bool: +def is_valid_decilm_checkpoint(checkpoint_dir: Path | str, trust_remote_code: bool = False) -> bool: """Validate that a checkpoint is in DeciLM format (has block_configs). Args: checkpoint_dir: Path to checkpoint directory + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. Returns: True if checkpoint is valid DeciLM format, False otherwise """ try: - model_config = load_model_config(checkpoint_dir) + model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) if model_config.block_configs is None: warnings.warn( f"Skipping checkpoint '{checkpoint_dir}' - not in DeciLM format (missing block_configs)" diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 3c3b54830a..3647de5e25 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -73,10 +73,19 @@ def load_checkpoint( checkpoint_dir: Path | str, model_config_overrides: dict | None = None, ignore_unexpected_config_keys: bool = False, + trust_remote_code: bool = False, ) -> DeciLMForCausalLM: """ Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your local repo code, not the code inside the checkpoint. + + Args: + checkpoint_dir: Path to checkpoint directory + model_config_overrides: Optional mapping of config overrides. + ignore_unexpected_config_keys: If True, ignore unexpected config keys. + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. """ from modelopt.torch.puzzletron.tools.checkpoint_utils import ( load_state_dict, # prevent circular import @@ -86,7 +95,10 @@ def load_checkpoint( checkpoint_dir = Path(checkpoint_dir) model_config = load_model_config( - checkpoint_dir, model_config_overrides, ignore_unexpected_config_keys + checkpoint_dir, + model_config_overrides=model_config_overrides, + ignore_unexpected_config_keys=ignore_unexpected_config_keys, + trust_remote_code=trust_remote_code, ) # Without sparsity we could have done: @@ -221,7 +233,17 @@ def _save_checkpoint( ) -def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: +def split_checkpoint_to_subblocks( + checkpoint_dir: Path | str, trust_remote_code: bool = False +) -> None: + """Split a checkpoint into subblocks. + + Args: + checkpoint_dir: Path to checkpoint directory + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + """ from modelopt.torch.puzzletron.tools.checkpoint_utils import ( load_state_dict, # prevent circular import ) @@ -229,7 +251,7 @@ def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) - model_config = load_model_config(checkpoint_dir) + model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) state_dict = load_state_dict(checkpoint_dir) save_subblocks(state_dict, checkpoint_dir) diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py index 1cf02dc931..55926eaaea 100644 --- a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -115,7 +115,9 @@ def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtime): - all_block_indexes = set(range(model.config.num_hidden_layers)) + # Get language model config (handles nested configs like Qwen3-VL's text_config) + lm_config = descriptor.get_language_model_config(model.config) + all_block_indexes = set(range(lm_config.num_hidden_layers)) has_first_block = 0 in owned_block_indexes has_last_block = max(all_block_indexes) in owned_block_indexes @@ -136,13 +138,13 @@ def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtim set_submodule( model, descriptor.input_embedding_name(), - DummyWTE(model.config.hidden_size, dtype=runtime.dtype), + DummyWTE(lm_config.hidden_size, dtype=runtime.dtype), ) if not has_last_block: set_submodule(model, descriptor.final_norm_name(), nn.Identity()) if not (model.config.tie_word_embeddings and has_first_block): - set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(model.config)) + set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(lm_config)) return model @@ -202,11 +204,13 @@ def load_and_shard_model( with runtime.device: if model_config is None: - model_config = load_model_config(checkpoint_path) + trust_remote_code = descriptor.requires_trust_remote_code() + model_config = load_model_config(checkpoint_path, trust_remote_code=trust_remote_code) + num_hidden_layers = descriptor.get_language_model_config(model_config).num_hidden_layers if owned_block_indexes == "auto": owned_block_indexes = set( - np.array_split(np.arange(model_config.num_hidden_layers), runtime.world_size)[ + np.array_split(np.arange(num_hidden_layers), runtime.world_size)[ runtime.global_rank ] ) @@ -250,7 +254,7 @@ def load_and_shard_model( # Re-tie weights after load_state_dict with assign=True, which severs the tie. # Needed on first rank (owns embed_tokens) and last rank (owns lm_head). has_first_block = 0 in owned_block_indexes - has_last_block = (model_config.num_hidden_layers - 1) in owned_block_indexes + has_last_block = (num_hidden_layers - 1) in owned_block_indexes if model_config.tie_word_embeddings and (has_first_block or has_last_block): model_shard.tie_weights() @@ -309,7 +313,8 @@ def create_sharded_model( model_class = _get_model_class_from_config(model_config) # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() if model_class is AutoModelForCausalLM: - model = model_class.from_config(model_config, trust_remote_code=True) + trust_remote_code = descriptor.requires_trust_remote_code() + model = model_class.from_config(model_config, trust_remote_code=trust_remote_code) else: model = model_class._from_config(model_config) create_local_shard_( diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/ffn_pruning.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/pruning/ffn_pruning.yaml deleted file mode 100644 index f0c852eec9..0000000000 --- a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/ffn_pruning.yaml +++ /dev/null @@ -1,12 +0,0 @@ -defaults: - - pruning_defaults - -activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} - -activation_hooks_kwargs: - method: iterative - target_layer: "mlp.down_proj" - layer_input_descriptors_path: - -intermediate_size_list: [256] # teacher_intermediate_size is 14336 -mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml deleted file mode 100644 index 0a5eafcfff..0000000000 --- a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml +++ /dev/null @@ -1,32 +0,0 @@ -defaults: - - /validate_model_defaults - -model_name_or_path: ${teacher_dir} -experiment_id: ${pruning.eval_samples}samples_diverse_mini -activations_log_dir: ??? -activation_hooks_kwargs: ??? - -# Data: -eval_samples: 100 -micro_batch_size: 4 -dataset_path: ${dataset_path} -val_dataset_name: train - -# Prune ckpts -pruned_ckpts_outpt_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} - -## FFN pruning -ffn_list: -mlp_init_mode: "Truncate" - -## KV-heads pruning -n_heads_in_group_list: -gqa_init_mode: "AverageKV" - -## Hidden dimension pruning -hidden_size_list: -hidden_size_init_mode: "PruneByChannelRanking" -linear_init_mode: "FromTeacher" - -mlp_init_config_yaml: - activations_log_dir: ${pruning.activations_log_dir} diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/validate_model_defaults.yaml b/tests/_test_utils/torch/puzzletron/resources/configs/validate_model_defaults.yaml deleted file mode 100644 index 1d042d75df..0000000000 --- a/tests/_test_utils/torch/puzzletron/resources/configs/validate_model_defaults.yaml +++ /dev/null @@ -1,17 +0,0 @@ -model_dtype: torch.bfloat16 -autocast_dtype: torch.bfloat16 -block_size: 8192 -bos_rate: 0.5 -data_column: conversation -val_dataset_name: train -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 -source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/_test_utils/torch/puzzletron/resources/tokenizer/special_tokens_map.json b/tests/_test_utils/torch/puzzletron/resources/tokenizer/special_tokens_map.json deleted file mode 100644 index 02ee80b619..0000000000 --- a/tests/_test_utils/torch/puzzletron/resources/tokenizer/special_tokens_map.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "bos_token": { - "content": "<|begin_of_text|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "eos_token": { - "content": "<|eot_id|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - } -} diff --git a/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer.json b/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer.json deleted file mode 100644 index 83592e2494..0000000000 --- a/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer.json +++ /dev/null @@ -1,212 +0,0 @@ -{ - "version": "1.0", - "truncation": null, - "padding": null, - "added_tokens": [], - "normalizer": null, - "pre_tokenizer": { - "type": "Sequence", - "pretokenizers": [ - { - "type": "Split", - "pattern": { - "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" - }, - "behavior": "Isolated", - "invert": false - }, - { - "type": "ByteLevel", - "add_prefix_space": false, - "trim_offsets": true, - "use_regex": false - } - ] - }, - "post_processor": { - "type": "Sequence", - "processors": [ - { - "type": "ByteLevel", - "add_prefix_space": true, - "trim_offsets": false, - "use_regex": true - }, - { - "type": "TemplateProcessing", - "single": [ - { - "SpecialToken": { - "id": "<|begin_of_text|>", - "type_id": 0 - } - }, - { - "Sequence": { - "id": "A", - "type_id": 0 - } - } - ], - "pair": [ - { - "SpecialToken": { - "id": "<|begin_of_text|>", - "type_id": 0 - } - }, - { - "Sequence": { - "id": "A", - "type_id": 0 - } - }, - { - "SpecialToken": { - "id": "<|begin_of_text|>", - "type_id": 1 - } - }, - { - "Sequence": { - "id": "B", - "type_id": 1 - } - } - ], - "special_tokens": { - "<|begin_of_text|>": { - "id": "<|begin_of_text|>", - "ids": [ - 100 - ], - "tokens": [ - "<|begin_of_text|>" - ] - } - } - } - ] - }, - "decoder": { - "type": "ByteLevel", - "add_prefix_space": true, - "trim_offsets": true, - "use_regex": true - }, - "model": { - "type": "BPE", - "dropout": null, - "unk_token": null, - "continuing_subword_prefix": null, - "end_of_word_suffix": null, - "fuse_unk": false, - "byte_fallback": false, - "ignore_merges": true, - "vocab": { - "!": 0, - "\"": 1, - "#": 2, - "$": 3, - "%": 4, - "&": 5, - "'": 6, - "(": 7, - ")": 8, - "*": 9, - "+": 10, - ",": 11, - "-": 12, - ".": 13, - "/": 14, - "0": 15, - "1": 16, - "2": 17, - "3": 18, - "4": 19, - "5": 20, - "6": 21, - "7": 22, - "8": 23, - "9": 24, - ":": 25, - ";": 26, - "<": 27, - "=": 28, - ">": 29, - "?": 30, - "@": 31, - "A": 32, - "B": 33, - "C": 34, - "D": 35, - "E": 36, - "F": 37, - "G": 38, - "H": 39, - "I": 40, - "J": 41, - "K": 42, - "L": 43, - "M": 44, - "N": 45, - "O": 46, - "P": 47, - "Q": 48, - "R": 49, - "S": 50, - "T": 51, - "U": 52, - "V": 53, - "W": 54, - "X": 55, - "Y": 56, - "Z": 57, - "[": 58, - "\\": 59, - "]": 60, - "^": 61, - "_": 62, - "`": 63, - "a": 64, - "b": 65, - "c": 66, - "d": 67, - "e": 68, - "f": 69, - "g": 70, - "h": 71, - "i": 72, - "j": 73, - "k": 74, - "l": 75, - "m": 76, - "n": 77, - "o": 78, - "p": 79, - "q": 80, - "r": 81, - "s": 82, - "t": 83, - "u": 84, - "v": 85, - "w": 86, - "x": 87, - "y": 88, - "z": 89, - "{": 90, - "|": 91, - "}": 92, - "~": 93, - "¡": 94, - "¢": 95, - "£": 96, - "¤": 97, - "¥": 98, - "¦": 99, - "<|begin_of_text|>": 100, - "<|eot_id|>": 101 - }, - "merges": [] - } -} diff --git a/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer_config.json b/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer_config.json deleted file mode 100644 index 754d9e8db5..0000000000 --- a/tests/_test_utils/torch/puzzletron/resources/tokenizer/tokenizer_config.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "bos_token": "<|begin_of_text|>", - "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", - "clean_up_tokenization_spaces": true, - "eos_token": "<|eot_id|>", - "extra_special_tokens": {}, - "model_input_names": [ - "input_ids", - "attention_mask" - ], - "model_max_length": 131072, - "tokenizer_class": "PreTrainedTokenizer" -} diff --git a/tests/_test_utils/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py b/tests/_test_utils/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py deleted file mode 100644 index aedcae4ab2..0000000000 --- a/tests/_test_utils/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script was used to truncate the tokenizer.json file from Llama 3.1 8B model -to keep only the top 100 most common tokens. -""" - -import json - -# Path to your original and new tokenizer.json -in_path = "./tokenizer.json" -out_path = "./tokenizer_truncated.json" - -# How many top tokens to keep -NUM_TO_KEEP = 100 - -with open(in_path, encoding="utf-8") as f: - tokenizer_data = json.load(f) - -# Get and sort the original vocab by index (frequency proxy) -orig_vocab = tokenizer_data["model"]["vocab"] - -# Sort tokens by their original index (lowest index = assumed most common/important) -sorted_tokens = sorted(orig_vocab.items(), key=lambda item: item[1]) - -# Keep the top N tokens -tokens_to_keep = [tok for tok, idx in sorted_tokens[:NUM_TO_KEEP]] - -# Re-index the selected tokens: 0..N-1 -small_vocab = {tok: i for i, tok in enumerate(tokens_to_keep)} -tokenizer_data["model"]["vocab"] = small_vocab - -# Update vocab size -if "vocab_size" in tokenizer_data["model"]: - tokenizer_data["model"]["vocab_size"] = len(small_vocab) - -# Optionally remove merges if present and unneeded (mostly for BPE/WordPiece) -if "merges" in tokenizer_data["model"]: - tokenizer_data["model"]["merges"] = [] - -# Remove added_tokens if not needed -if "added_tokens" in tokenizer_data: - tokenizer_data["added_tokens"] = [] - -# Write out the truncated tokenizer.json -with open(out_path, "w", encoding="utf-8") as f: - json.dump(tokenizer_data, f, indent=2, ensure_ascii=False) - -print(f"Truncated tokenizer saved to: {out_path}") diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index 07d1565f42..7615c5d085 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -24,18 +24,12 @@ import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers -# Path to HF configs relative to this file -# HF configs are in tests/gpu/torch/puzzletron/resources/hf_configs -HF_CONFIGS_DIR = ( - Path(__file__).parent.parent.parent.parent / "gpu/torch/puzzletron/resources/hf_configs" -) - def setup_test_model_and_data( project_root_path: Path, tmp_path: Path, rank: int, - hf_config_name: str, + hf_model_name: str, hybrid_override_pattern: str | None = None, ) -> tuple[Path, Path, Path]: """ @@ -45,7 +39,7 @@ def setup_test_model_and_data( project_root_path (Path): the root path of the project tmp_path (Path): the temporary path to use for the test rank (int): the rank of the process - hf_config_name (str): Name of the HF config directory (e.g., "llama_3_1_8b_instruct") + hf_model_name (str): HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") hybrid_override_pattern (str): For NemotronH models, the layer type pattern Returns: @@ -56,10 +50,8 @@ def setup_test_model_and_data( # Register Hydra custom resolvers (needed for config resolution) register_hydra_resolvers() - # The inputs for the nas.convert() step. - # - puzzle_dir = tmp_path / hf_config_name - hf_checkpoint_path = puzzle_dir / f"hf_models/{hf_config_name}" + puzzle_dir = tmp_path / hf_model_name + hf_checkpoint_path = puzzle_dir / f"hf_models/{hf_model_name}" dataset_path = puzzle_dir / "dummy_dataset" if rank == 0: @@ -73,7 +65,7 @@ def setup_test_model_and_data( output_path=str(hf_checkpoint_path), vocab_size=tokenizer.vocab_size, tokenizer=tokenizer, - hf_config_name=hf_config_name, + hf_model_name=hf_model_name, hybrid_override_pattern=hybrid_override_pattern, ) dist.barrier() @@ -89,7 +81,7 @@ def create_and_save_small_hf_model( output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase, - hf_config_name: str, + hf_model_name: str, hybrid_override_pattern: str | None = None, ): """ @@ -101,23 +93,21 @@ def create_and_save_small_hf_model( output_path: Where to save the model vocab_size: Vocabulary size (should match tokenizer) tokenizer: Tokenizer to save alongside the model - hf_config_name: Name of the config directory under resources/hf_configs/ - e.g., "llama_3_1_8b_instruct", "llama_3_2_3b_instruct", or "qwen2_5_7b_instruct" + hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") hybrid_override_pattern: For NemotronH models, the layer type pattern (e.g., "*-" for Attention+MLP, "M-" for Mamba+MLP). Must match num_hidden_layers. None for non-NemotronH models. """ os.makedirs(output_path, exist_ok=True) # Load real HuggingFace config (preserves tie_word_embeddings, rope_scaling, etc.) - config_path = HF_CONFIGS_DIR / hf_config_name - config = AutoConfig.from_pretrained(config_path, local_files_only=True, trust_remote_code=True) + config = AutoConfig.from_pretrained(hf_model_name, trust_remote_code=True) # Override size-related params to make it small for testing # Note: intermediate_size must be divisible by 256 per DeciLM config requirements # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility # VL models have nested configs (text_config, vision_config) - if hf_config_name == "qwen3-vl-30b-a3b-instruct": + if hasattr(config, "text_config") and hasattr(config, "vision_config"): config.text_config.vocab_size = vocab_size config.text_config.hidden_size = 256 config.text_config.intermediate_size = 512 @@ -160,14 +150,34 @@ def create_and_save_small_hf_model( torch.manual_seed(42) # Create and save the model + # Force CPU initialization for deterministic behavior (prevents NaN on RTX GPUs) + original_cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + os.environ["CUDA_VISIBLE_DEVICES"] = "" # TODO: Consider using AutoModel.from_config instead. - if hf_config_name == "qwen3-vl-30b-a3b-instruct": + if hasattr(config, "text_config") and hasattr(config, "vision_config"): from transformers import Qwen3VLMoeForConditionalGeneration model = Qwen3VLMoeForConditionalGeneration._from_config(config) else: model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + # Initialize weights to ensure all parameters are properly initialized + # This prevents NaN values in uninitialized parameters (e.g., backbone.layers.1.mixer.gate.weight + # in nemotron-3-nano-30b-a3b-base-bf16) that can occur with from_config on RTX GPU cards (not on H100) + model.initialize_weights() + + # Fix any remaining NaN/Inf values that initialize_weights() might have missed + for name, param in model.named_parameters(): + if torch.isnan(param).any() or torch.isinf(param).any(): + nan_inf_mask = torch.isnan(param) | torch.isinf(param) + param.data = torch.where(nan_inf_mask, torch.zeros_like(param), param) + + # Restore CUDA_VISIBLE_DEVICES after model creation and initialization + if original_cuda_visible is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible + else: + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + model.to(dtype=torch.bfloat16).save_pretrained(output_path) # Save tokenizer diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index e2373676d2..8a5bad0c62 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -18,7 +18,6 @@ from functools import partial from pathlib import Path -import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data @@ -28,7 +27,6 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel -@pytest.mark.skip(reason="Temporarily disabled") def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -43,12 +41,10 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" + project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" ) - hydra_config_dir = ( - project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" - ) - hydra_config_name = "llama_3_1_8b_instruct" + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" # # Run the mnt.convert() step @@ -87,7 +83,6 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.cleanup() -@pytest.mark.skip(reason="Temporarily disabled") def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -102,12 +97,10 @@ def _test_nas_convert_attn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" - ) - hydra_config_dir = ( - project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" + project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" ) - hydra_config_name = "llama_3_1_8b_instruct-attn-pruning" + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning" # # Run the mnt.convert() step diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index e39f1e1cbc..2af371e5ca 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -17,7 +17,6 @@ from functools import partial from pathlib import Path -import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data @@ -27,7 +26,6 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel -@pytest.mark.skip(reason="Temporarily disabled") def test_nas_search(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -42,12 +40,10 @@ def _test_nas_search_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" + project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" ) - hydra_config_dir = ( - project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" - ) - hydra_config_name = "llama_3_1_8b_instruct" + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" # # Run the mnt.convert() step diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-ffn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml similarity index 76% rename from tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-ffn-pruning.yaml rename to tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml index 8af352660b..2843f0b97a 100644 --- a/tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-ffn-pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml @@ -1,18 +1,19 @@ +# @package _global_ defaults: - - pruning: ffn_pruning - - scoring: ../validate_solutions_defaults - - realize_model: ../validate_solutions_defaults - - bypass: - - override hydra/hydra_logging: disabled + - /Qwen/Qwen2.5-7B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model - _self_ puzzle_dir: ??? teacher_dir: ${puzzle_dir}/ckpts/teacher/ replacement_library_path: ${puzzle_dir}/replacement_library.json -dataset_path: ??? # path to v0.4_mini +dataset_path: ??? # path to v0.4_mini skip_realize_model: false +descriptor: qwen2 + build_replacement_library: add_ffn_no_ops: true add_attention_no_ops: true @@ -21,15 +22,17 @@ calc_subblock_stats: batch_sizes: [64, 96, 128] prefill_seq_len: 4096 generation_seq_len: 4096 - num_active_tokens_override: # Optional override for sequence lengths + num_active_tokens_override: # Optional override for sequence lengths prefill_queue_size: 0 allocate_prefill_query: false - benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking merge_with_existing_stats: false subblock_stats_filename: "subblock_stats.json" moe_stats_filename: "moe_stats.json" scoring: + descriptor: ${descriptor} + solutions_to_validate: skip_existing_solutions: true @@ -54,6 +57,8 @@ mip: # puzzle_profile: objective: metrics.cosine_embedding_loss_hidden_states bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 subblock_stats_args: - batch_size: 96 @@ -77,18 +82,23 @@ mip: target_memory: 780_000 # 78_000 mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true metric_overrides: + constrain_search_func: max_seconds_per_solution: 60 realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} tokenizer_name: ${to_path:${teacher_dir}} replacement_library_path: ${replacement_library_path} save_models: true - solutions_path: # Filled dynamically + solutions_path: # Filled dynamically # Validate params - skip_validation: false # To enable validation of the model solution set `skip_validation` as False + skip_validation: false # To enable validation of the model solution set `skip_validation` as False eval_samples: 2 micro_batch_size: 1 dataset_path: ${dataset_path}/valid diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..cf6201080c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml similarity index 76% rename from tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-attn-pruning.yaml rename to tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml index 473a5d418d..cd82a47271 100644 --- a/tests/_test_utils/torch/puzzletron/resources/configs/Llama-3_1-8B-attn-pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml @@ -1,18 +1,19 @@ +# @package _global_ defaults: - - pruning: attn_pruning - - scoring: ../validate_solutions_defaults - - realize_model: ../validate_solutions_defaults - - bypass: - - override hydra/hydra_logging: disabled + - /Qwen/Qwen3-8B/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model - _self_ puzzle_dir: ??? teacher_dir: ${puzzle_dir}/ckpts/teacher/ replacement_library_path: ${puzzle_dir}/replacement_library.json -dataset_path: ??? # path to v0.4_mini +dataset_path: ??? # path to v0.4_mini skip_realize_model: false +descriptor: qwen3 + build_replacement_library: add_ffn_no_ops: true add_attention_no_ops: true @@ -21,15 +22,16 @@ calc_subblock_stats: batch_sizes: [64, 96, 128] prefill_seq_len: 4096 generation_seq_len: 4096 - num_active_tokens_override: # Optional override for sequence lengths + num_active_tokens_override: # Optional override for sequence lengths prefill_queue_size: 0 - allocate_prefill_query: false - benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking merge_with_existing_stats: false subblock_stats_filename: "subblock_stats.json" moe_stats_filename: "moe_stats.json" scoring: + descriptor: ${descriptor} + solutions_to_validate: skip_existing_solutions: true @@ -54,6 +56,8 @@ mip: # puzzle_profile: objective: metrics.cosine_embedding_loss_hidden_states bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 subblock_stats_args: - batch_size: 96 @@ -77,18 +81,23 @@ mip: target_memory: 780_000 # 78_000 mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true metric_overrides: + constrain_search_func: max_seconds_per_solution: 60 realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} tokenizer_name: ${to_path:${teacher_dir}} replacement_library_path: ${replacement_library_path} save_models: true - solutions_path: # Filled dynamically + solutions_path: # Filled dynamically # Validate params - skip_validation: false # To enable validation of the model solution set `skip_validation` as False + skip_validation: false # To enable validation of the model solution set `skip_validation` as False eval_samples: 2 micro_batch_size: 1 dataset_path: ${dataset_path}/valid diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..e6e6ce5bb4 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor.Qwen3_8BFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml new file mode 100644 index 0000000000..00b21ea979 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /Qwen/Qwen3-VL-30B-A3B-Instruct/pruning@pruning: expert_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: qwen3_vl + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + - stats.num_local_experts + + human_constraints: + + mip_constraints: + - stats.num_local_experts: 1472 # same constraint as nemotron-3-nano for test consistency + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml new file mode 100644 index 0000000000..81c5f35ba5 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml @@ -0,0 +1,20 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor.Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor + target_name: "mlp" + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.Qwen3VLRemoveExpertsIndependentHook} +activation_hooks_kwargs: + +# num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist) +num_experts_to_keep_list: [8] # num_experts in test model is 16, num_experts_per_tok is 8 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks_mse" + layer_prefix_template: "model.language_model.layers.{layer_idx}.mlp" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml deleted file mode 100644 index 01886607e4..0000000000 --- a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml +++ /dev/null @@ -1,16 +0,0 @@ -defaults: - - pruning_defaults - -activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} - -activation_hooks_kwargs: - method: independent_kv_head_contribution - optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory - target_layer: "self_attn.o_proj" - layer_input_descriptors_path: - -# n_heads_in_group: 4 -# num_attention_heads: 32 # num query heads -# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group -n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] -gqa_init_mode: "PruneKVHeads" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml deleted file mode 100644 index 407c835d8c..0000000000 --- a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml +++ /dev/null @@ -1,15 +0,0 @@ -defaults: - - pruning_defaults - -activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} - -activation_hooks_kwargs: - method: layer_norm_contribution - target_layer: "layernorm" - -# Hidden dimension pruning specific settings -hidden_size_list: [3072, 2048] # Target hidden sizes to prune to -hidden_size_init_mode: "PruneByChannelRanking" -mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher -gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher -linear_init_mode: "FromTeacher" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml deleted file mode 100644 index ec13902379..0000000000 --- a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml +++ /dev/null @@ -1,10 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml new file mode 100644 index 0000000000..57051431a1 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: attn_pruning + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +dataset_path: ??? diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml similarity index 94% rename from tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml rename to tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml index 02c73aca69..8e2e0786b3 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml @@ -1,9 +1,8 @@ +# @package _global_ defaults: - - pruning: attn_pruning - - scoring: ../validate_solutions_defaults - - realize_model: ../validate_solutions_defaults - - bypass: - - override hydra/hydra_logging: disabled + - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model - _self_ descriptor: llama diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..6e8af1f651 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/attn_pruning@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaKVHeadsLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..b30f4a17d9 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml similarity index 94% rename from tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml rename to tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml index 65ca64ef4e..78cb6bd73c 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml @@ -1,9 +1,8 @@ +# @package _global_ defaults: - - pruning: ffn_pruning - - scoring: ../validate_solutions_defaults - - realize_model: ../validate_solutions_defaults - - bypass: - - override hydra/hydra_logging: disabled + - /meta-llama/Llama-3.2-3B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model - _self_ descriptor: llama diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..b30f4a17d9 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml new file mode 100644 index 0000000000..e042c4bb62 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml @@ -0,0 +1,112 @@ +# @package _global_ +defaults: + - /mistralai/Mistral-Small-24B-Instruct-2501/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: mistral_small + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..37c21fd638 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml new file mode 100644 index 0000000000..ab2b09e679 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml @@ -0,0 +1,115 @@ +# @package _global_ +defaults: + - /nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning@pruning: expert_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: nemotron_h + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path}/valid + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + - stats.num_local_experts + + human_constraints: + mip_constraints: + - stats.num_local_experts: 1472 # teacher has: 23 moe-blocks * 128 experts = 2944 total experts use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path}/valid + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml new file mode 100644 index 0000000000..4c2335becf --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor.NemotronHExpertRemovalLayerDescriptor + target_name: "mixer" + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.NemotronHRemoveExpertsIndependentHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [96, 64, 32, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks_mse" diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..cb1147d86b --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml @@ -0,0 +1,14 @@ +defaults: + - /pruning/pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml new file mode 100644 index 0000000000..906b7338d8 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: nemotron_h_v2 + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..f68068c3ac --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml @@ -0,0 +1,12 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor + +activation_hooks_kwargs: + method: iterative + target_layer: "mixer.down_proj" + layer_input_descriptors_path: diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml similarity index 67% rename from tests/_test_utils/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml rename to tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml index 01886607e4..7306b6e379 100644 --- a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml @@ -1,8 +1,15 @@ defaults: - - pruning_defaults + - /pruning/pruning_defaults@_here_ + - _self_ activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin.KVHeadsPruningMixIn + layer_descriptor: + _target_: ??? + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IndependentKvHeadContributionHook} activation_hooks_kwargs: method: independent_kv_head_contribution optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml similarity index 72% rename from tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml rename to tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml index cad6fcf3ee..7e19afbbce 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml @@ -1,12 +1,13 @@ defaults: - - pruning_defaults + - /pruning/pruning_defaults@_here_ + - _self_ activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} pruning_mixin: _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn layer_descriptor: - _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + _target_: ??? hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: @@ -14,5 +15,5 @@ activation_hooks_kwargs: target_layer: "mlp.down_proj" layer_input_descriptors_path: -intermediate_size_list: [256] # teacher_intermediate_size is 14336 +intermediate_size_list: [256] mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml similarity index 93% rename from tests/_test_utils/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml rename to tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml index 407c835d8c..4033fedf3a 100644 --- a/tests/_test_utils/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml @@ -1,5 +1,5 @@ defaults: - - pruning_defaults + - /pruning/pruning_defaults@_here_ activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml similarity index 94% rename from tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml rename to tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml index b24ea1b7cc..f00a86da66 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml @@ -1,12 +1,13 @@ defaults: - - /validate_model_defaults + - /validate_model_defaults@_here_ -descriptor: ${descriptor} model_name_or_path: ${teacher_dir} experiment_id: ${pruning.eval_samples}samples_diverse_mini activations_log_dir: ??? activation_hooks_kwargs: ??? +descriptor: ${descriptor} + # Data: eval_samples: 100 micro_batch_size: 4 diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml similarity index 100% rename from tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml rename to tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml diff --git a/tests/_test_utils/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml similarity index 100% rename from tests/_test_utils/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml rename to tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml diff --git a/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json b/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json deleted file mode 100644 index 0bb6fd75b3..0000000000 --- a/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": [ - 128001, - 128008, - 128009 - ], - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 14336, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.42.3", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index a42a716547..cf600558e5 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -21,6 +21,7 @@ import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.utils.distributed as dist @@ -31,46 +32,30 @@ # using a one-click command. # # Note: Bypass is disabled now in the test. +# + +SEED = 1234 @pytest.mark.parametrize( - ( - "hf_config_name", - "converter", - "hydra_config_subdir", - "hybrid_override_pattern", - "has_moe_layers", - ), + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), [ - ("llama_3_1_8b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - # ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - # ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), - # ( - # "mistral-small-24b-instruct-2501", - # "mistral_small", - # "mistral-small-24b-instruct-2501", - # None, - # False, - # ), - # ("qwen3-8b", "qwen3", "qwen3-8b", None, False), - # ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), - # ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), - # ( - # "nemotron-3-nano-30b-a3b-base-bf16", - # "nemotron_h", - # "nemotron-3-nano-30b-a3b-base-bf16", - # "*E", - # True, - # ), - # ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), + ("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False), + ("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False), + ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small", None, False), + ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", "nemotron_h", "*E", True), + ("nvidia/NVIDIA-Nemotron-Nano-12B-v2", "nemotron_h_v2", "*-", False), + # ("openai/gpt-oss-20b", "gpt_oss", None, True), + ("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False), + ("Qwen/Qwen3-8B", "qwen3", None, False), + ("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True), ], ) def test_puzzletron( project_root_path: Path, tmp_path: Path, - hf_config_name: str, + hf_model_name: str, converter: str, - hydra_config_subdir: str, hybrid_override_pattern: str, has_moe_layers: bool, ): @@ -80,9 +65,8 @@ def test_puzzletron( _test_puzzletron_multiprocess_job, project_root_path, tmp_path, - hf_config_name, + hf_model_name, converter, - hydra_config_subdir, hybrid_override_pattern, has_moe_layers, ), @@ -93,23 +77,25 @@ def test_puzzletron( def _test_puzzletron_multiprocess_job( project_root_path: Path, tmp_path: Path, - hf_config_name: str, + hf_model_name: str, converter: str, - hydra_config_subdir: str, hybrid_override_pattern: str, has_moe_layers: bool, rank: int, size: int, ): + # Set seed BEFORE dist.setup() to ensure reproducibility across all processes + set_seed(SEED) + dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, hf_config_name, hybrid_override_pattern - ) - hydra_config_dir = ( - project_root_path / f"tests/gpu/torch/puzzletron/resources/configs/{hydra_config_subdir}" + project_root_path, tmp_path, rank, hf_model_name, hybrid_override_pattern ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + model_basename = hf_model_name.split("/")[1] + hydra_config_name = f"{hf_model_name}/{model_basename}" # Convert the model using AnyModel converter. if rank == 0: @@ -122,7 +108,7 @@ def _test_puzzletron_multiprocess_job( # Compress the model using a one-click approach puzzletron.puzzletron( - str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) + str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path) ) # @@ -159,16 +145,16 @@ def _test_puzzletron_multiprocess_job( assert (solution_dir / "solutions.json").exists() # Validate lm_loss - _assert_lm_loss(puzzle_dir, hf_config_name) + _assert_lm_loss(puzzle_dir, hf_model_name, tolerance=0.01) else: # assertions for the score_pruning_activations step 1 (FFN pruning) - _assert_score_pruning_activations(puzzle_dir, hf_config_name) + _assert_score_pruning_activations(puzzle_dir, hf_model_name) # assertions for the pruning_ckpts step 2 assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() # assertions for the mip_and_realize_models step 6 - _assert_mip_solutions(puzzle_dir, hf_config_name) + _assert_mip_solutions(puzzle_dir, hf_model_name) # assertions for the build_library_and_stats step 4 assert (puzzle_dir / "replacement_library.json").is_file() @@ -183,7 +169,7 @@ def _test_puzzletron_multiprocess_job( dist.cleanup() print( - f"PYTEST SUMMARY: test_puzzletron({hf_config_name}) test has finished successfully. " + f"PYTEST SUMMARY: test_puzzletron({hf_model_name}) test has finished successfully. " f"Puzzle directory: {puzzle_dir}" ) @@ -191,52 +177,50 @@ def _test_puzzletron_multiprocess_job( # Expected pruning activation values per model # Each model has a list of (score, channels) tuples for each FFN layer EXPECTED_PRUNING_VALUES = { - "llama_3_1_8b_instruct": [ + "meta-llama/Llama-3.1-8B-Instruct": [ {"score": 73, "channels": 95}, {"score": 440, "channels": 174}, ], - "llama_3_2_3b_instruct": [ + "meta-llama/Llama-3.2-3B-Instruct": [ {"score": 79, "channels": 95}, {"score": 428, "channels": 174}, ], - "qwen2_5_7b_instruct": [ - {"score": 96, "channels": 433}, - {"score": 485, "channels": 105}, - ], - # Mistral Small 24B - "mistral-small-24b-instruct-2501": [ + "mistralai/Mistral-Small-24B-Instruct-2501": [ {"score": 73, "channels": 95}, {"score": 431, "channels": 174}, ], - # Qwen3 8B - "qwen3-8b": [ - {"score": 208, "channels": 51}, - {"score": 475, "channels": 266}, - ], # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer) - "nemotron-nano-12b-v2": [ + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": [ {"score": 70, "channels": 509}, ], - # Note: nemotron-3-nano-30b-a3b-base-bf16 uses MoE expert pruning, not FFN pruning - # so it doesn't have EXPECTED_PRUNING_VALUES + # nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 uses MoE expert pruning, not FFN pruning + "Qwen/Qwen2.5-7B-Instruct": [ + {"score": 96, "channels": 433}, + {"score": 485, "channels": 105}, + ], + "Qwen/Qwen3-8B": [ + {"score": 208, "channels": 51}, + {"score": 475, "channels": 266}, + ], } # Expected lm_loss values per model EXPECTED_LM_LOSS = { - "llama_3_1_8b_instruct": 4.706878662109375, - "llama_3_2_3b_instruct": 4.816886901855469, - "qwen2_5_7b_instruct": 4.778186798095703, - "nemotron-nano-12b-v2": 4.79390811920166, - "mistral-small-24b-instruct-2501": 4.709150314331055, - "qwen3-8b": 4.733874320983887, - "gpt-oss-20b": 4.689250946044922, - "nemotron-3-nano-30b-a3b-base-bf16": 4.741103172302246, - "qwen3-vl-30b-a3b-instruct": 4.65625, + "meta-llama/Llama-3.1-8B-Instruct": 4.706878662109375, + "meta-llama/Llama-3.2-3B-Instruct": 4.816886901855469, + "mistralai/Mistral-Small-24B-Instruct-2501": 4.709150314331055, + # TODO: not reproducible in CI, skipping for now + # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 4.7737884521484375, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 4.79390811920166, + # "openai/gpt-oss-20b": 4.689250946044922, + "Qwen/Qwen2.5-7B-Instruct": 4.778186798095703, + "Qwen/Qwen3-8B": 4.733874320983887, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 4.65625, } -def _assert_score_pruning_activations(puzzle_dir: Path, hf_config_name: str): +def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): """Assertions for the score_pruning_activations step 1.""" rank = dist.rank() rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" @@ -245,7 +229,7 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_config_name: str): pruning_scores = torch.load(puzzle_dir / rank_filepath) layer_names = list(pruning_scores.keys()) - expected = EXPECTED_PRUNING_VALUES[hf_config_name] + expected = EXPECTED_PRUNING_VALUES[hf_model_name] size = dist.size() if expected is not None: @@ -267,8 +251,8 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_config_name: str): ) else: # Print values for new models - update EXPECTED_PRUNING_VALUES with these - print(f"\n=== PRUNING VALUES for {hf_config_name} (num_layers={len(layer_names)}) ===") - print(f'"{hf_config_name}": [') + print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={len(layer_names)}) ===") + print(f'"{hf_model_name}": [') for layer_name in layer_names: layer_data = pruning_scores[layer_name] score = layer_data["score"][0].item() @@ -278,7 +262,7 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_config_name: str): print("===") -def _assert_lm_loss(puzzle_dir: Path, hf_config_name: str): +def _assert_lm_loss(puzzle_dir: Path, hf_model_name: str, tolerance: float = 0.01): """Validate lm_loss for a model solution.""" solution_0_path = ( puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" @@ -287,19 +271,19 @@ def _assert_lm_loss(puzzle_dir: Path, hf_config_name: str): validation = json.load(f) actual_lm_loss = validation["lm_loss"]["avg"] - expected_lm_loss = EXPECTED_LM_LOSS.get(hf_config_name) + expected_lm_loss = EXPECTED_LM_LOSS.get(hf_model_name) if expected_lm_loss is not None: - assert abs(actual_lm_loss - expected_lm_loss) < 0.01, ( + assert abs(actual_lm_loss - expected_lm_loss) < tolerance, ( f"lm_loss mismatch: expected {expected_lm_loss}, got {actual_lm_loss}" ) else: # Print value for new models - update EXPECTED_LM_LOSS with this - print(f"\n=== LM_LOSS for {hf_config_name} ===") - print(f'"{hf_config_name}": {actual_lm_loss},') + print(f"\n=== LM_LOSS for {hf_model_name} ===") + print(f'"{hf_model_name}": {actual_lm_loss},') print("===") -def _assert_mip_solutions(puzzle_dir: Path, hf_config_name: str): +def _assert_mip_solutions(puzzle_dir: Path, hf_model_name: str): """Assertions for the mip_and_realize_models step.""" mip_dir = puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB" @@ -307,4 +291,4 @@ def _assert_mip_solutions(puzzle_dir: Path, hf_config_name: str): assert (mip_dir / "solutions--checkpoints/solution_0/config.json").exists() # Validate lm_loss - _assert_lm_loss(puzzle_dir, hf_config_name) + _assert_lm_loss(puzzle_dir, hf_model_name) diff --git a/tox.ini b/tox.ini index bcfb41fca3..33700288b8 100644 --- a/tox.ini +++ b/tox.ini @@ -73,6 +73,8 @@ commands = [testenv:cuda13-gpu-puzzletron] commands_pre = # Install deps here so that it gets installed even in --current-env + pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git + pip install --no-build-isolation git+https://github.com/Dao-AILab/causal-conv1d.git pip install -e .[hf,puzzletron,dev-test] commands = # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" From 67999ebac082b1643fb5d6abf6cb1182d0b28f38 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 17 Mar 2026 14:16:39 +0100 Subject: [PATCH 42/62] Dkorzekwa/anymodel gptoss (#1020) ### What does this PR do? Merging dkorzekwa/anymodel_gptoss into dkorzekwa/any_model_other_models - this MR is only for reviewing. Ultimately dkorzekwa/anymodel_gptoss should be merged into feature/puzzletron once dkorzekwa/any_model_other_models is merged there. ## Summary by CodeRabbit * **New Features** * Added support for GPT-OSS, Nemotron V2, Qwen2, and Qwen3 models. * Enabled MXFP4 quantization for GPT-OSS model compression. * Added expert removal pruning for mixture-of-experts (MoE) models. * **Bug Fixes** * Fixed padding token validation to ensure it doesn't exceed vocabulary size. * **Tests** * Enabled test coverage for GPT-OSS model workflows. --------- Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 2 +- .../puzzletron/anymodel/models/__init__.py | 2 +- .../anymodel/models/gpt_oss/__init__.py | 22 + .../models/gpt_oss/gpt_oss_converter.py | 74 +++ .../gpt_oss/gpt_oss_model_descriptor.py | 236 ++++++++ .../models/gpt_oss/gpt_oss_pruned_to_mxfp4.py | 524 ++++++++++++++++++ tests/_test_utils/torch/puzzletron/utils.py | 4 + .../openai/gpt-oss-20b/gpt-oss-20b.yaml | 109 ++++ .../gpt-oss-20b/pruning/expert_removal.yaml | 19 + tests/gpu/torch/puzzletron/test_puzzletron.py | 3 +- 10 files changed, 991 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py create mode 100644 tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 807c1200e6..966aaedd55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -95,6 +95,7 @@ repos: modelopt/torch/speculative/eagle/utils.py| modelopt/torch/speculative/plugins/transformers.py| modelopt/torch/utils/plugins/megatron_mmlu.py| + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py| examples/chained_optimizations/bert_prune_distill_quantize.py| examples/deepseek/quantize_to_nvfp4.py| examples/deepseek/ptq.py| @@ -113,7 +114,6 @@ repos: examples/speculative_decoding/server_generate.py| experimental/dms/models/qwen3/configuration_qwen3_dms.py| experimental/dms/models/qwen3/modeling_qwen3_dms.py| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py| )$ # Default hook for Apache 2.0 in c/c++/cuda files diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py index 1f3fb477be..34d7ce5e5a 100644 --- a/modelopt/torch/puzzletron/anymodel/models/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. # Import models to trigger factory registration -# from modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b import * +from modelopt.torch.puzzletron.anymodel.models.gpt_oss import * from modelopt.torch.puzzletron.anymodel.models.llama import * from modelopt.torch.puzzletron.anymodel.models.mistral_small import * from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py new file mode 100644 index 0000000000..9f72b8dd78 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPT-OSS model support for AnyModel.""" + +from .gpt_oss_converter import GptOssConverter +from .gpt_oss_model_descriptor import GptOssModelDescriptor diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py new file mode 100644 index 0000000000..3e7371aaee --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""GPT-OSS-20B converter for AnyModel compression.""" + +from typing import List + +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + MoEConfig, +) + + +@ConverterFactory.register_decorator("gpt_oss") +class GptOssConverter(Converter): + """Converter for GPT-OSS models to AnyModel format. + + GPT-OSS is a pure MoE model with 32/128 experts per layer and 4/16 active experts. + All layers use MoE FFN (no standard dense FFN layers). + """ + + quantized = "mxfp4" + + @staticmethod + def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: + """Create block configs for GPT-OSS layers. + + GPT-OSS uses MoE for all FFN layers with: + - 32/128 local experts (num_local_experts) + - 4/16 active experts per token (experts_per_token) + - No dense/standard FFN layers + """ + num_hidden_layers = config.num_hidden_layers + num_local_experts = config.num_local_experts + experts_per_token = config.experts_per_token + intermediate_size = config.intermediate_size + + block_configs = [] + for layer_idx in range(num_hidden_layers): + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig( + no_op=False, + intermediate_size=None, # MoE doesn't use this field + moe=MoEConfig( + num_local_experts=num_local_experts, + num_experts_per_tok=experts_per_token, + expert_intermediate_dim=intermediate_size, + ), + ), + ).to_dict() + block_configs.append(block_config) + + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py new file mode 100644 index 0000000000..c77a4547f0 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""GPT-OSS model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Type + +import torch.nn as nn +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssDecoderLayer, GptOssRotaryEmbedding + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) + +# Expert removal is supported for unquantized models (test models). +# Production models use MXFP4 quantized MoE with combined tensors +# (gate_up_proj_blocks, down_proj_blocks), which is not yet supported. +from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock + + +@ModelDescriptorFactory.register_decorator("gpt_oss") +class GptOssModelDescriptor(ModelDescriptor): + """Model descriptor for GPT-OSS (pure MoE model).""" + + _DECODER_LAYER_CLS: Type[nn.Module] = None + + @classmethod + def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module: + dummy_block = DummyBlock(block_index=block_index) + # Required by `GptOssModel.forward`. + dummy_block.attention_type = original_layer.attention_type + return dummy_block + + @staticmethod + def decoder_layer_cls(): + """Get the decoder layer class for GPT-OSS models. + + GPT-OSS is a standard transformers model in recent versions. + Import directly from transformers.models.gpt_oss.modeling_gpt_oss. + """ + return GptOssDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + """Map BlockConfig to layer constructor overrides.""" + override_kwargs = {} + + if block_config.attention.num_key_value_heads is not None: + override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads + + if block_config.ffn.moe is not None: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["num_local_experts"] = block_config.ffn.moe.num_local_experts + override_kwargs["num_experts_per_tok"] = block_config.ffn.moe.num_experts_per_tok + + return override_kwargs + + @staticmethod + def attn_no_op_post_init(decoder_layer): + """Replace attention sublayers with no-op modules.""" + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer): + """Replace MLP sublayers with no-op modules. + + Note: GPT-OSS MoE layers return (hidden_states, router_scores), so we need + to return a tuple of 2 values. + """ + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def init_rotary_embedding(model, runtime): + """Initialize rotary embeddings on the correct device.""" + # GPT-OSS uses RoPE with YARN scaling + + model.model.rotary_emb = GptOssRotaryEmbedding( + config=model.config, + device=runtime.device, + ) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + """Define regex patterns for grouping weights into subblocks.""" + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + """FFN is MoE in GPT-OSS with MXFP4 quantization.""" + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\." + r"(post_attention_layernorm\.weight" + r"|mlp\.router\.weight" + r"|mlp\.router\.bias" + r"|mlp\.experts\.(gate_up_proj|down_proj)(_(bias|blocks|scales))?)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\." + r"(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.q_proj\.bias" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.k_proj\.bias" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.v_proj\.bias" + r"|self_attn\.o_proj\.weight" + r"|self_attn\.o_proj\.bias" + r"|self_attn\.sinks)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update( + **build_ffn_predicates(), + **build_attention_predicates(), + ) + + return layer_name_patterns + + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + """Return available pruning mixins for GPT-OSS. + + Note: Expert removal works for unquantized models (test models). + Production models use MXFP4 quantization which is not yet supported. + """ + return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())} + + +@dataclass +class GptOssExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): + """ + GPT-OSS MoE layer descriptor for expert removal. + + Note: This only works for unquantized models (e.g., test models). + Production GPT-OSS models use MXFP4 quantization with fused experts + (_blocks, _scales, _bias), which requires a different approach. + + Structure: + - Router: mlp.router with .weight and .bias + - Experts: mlp.experts.{idx}.{gate_up_proj,down_proj} with .weight and .bias + """ + + target_name: str = "mlp" + moe_prefix_name: str = "model.layers.{layer_idx}.mlp" + expert_prefix_name: str = "experts" + + # Router has both weight and bias + router_weights: List[str] = field(default_factory=lambda: ["router.weight"]) + router_biases: List[str] = field(default_factory=lambda: ["router.bias"]) + + # Fused format: experts stored as single tensors + is_fused_experts: bool = True + + # Fused format: single tensors containing all experts (test models) + fused_expert_weights: List[str] = field( + default_factory=lambda: [ + "experts.gate_up_proj", + "experts.gate_up_proj_bias", + "experts.down_proj", + "experts.down_proj_bias", + ] + ) + + # Not used for fused format, but kept for compatibility + expert_weights: List[str] = field(default_factory=lambda: ["gate_up_proj", "down_proj"]) + expert_biases: List[str] = field( + default_factory=lambda: ["gate_up_proj_bias", "down_proj_bias"] + ) + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_class_name = "GptOssTopKRouter" + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if ( + module_name.endswith(self.target_name) + and module.__class__.__name__ == target_class_name + ): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py new file mode 100644 index 0000000000..64d18921fd --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py @@ -0,0 +1,524 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Create a HuggingFace checkpoint with MXFP4 MoE weights from the original gpt-oss-120b model. + +This script: +1. Copies non-MoE weights from the student model (trained attention, embeddings, etc.) +2. Extracts MoE expert weights from the original gpt-oss-120b in MXFP4 format +3. Deduces expert mappings by comparing weights +4. Outputs a new pruned (heterogeneous) checkpoint with PACKED MXFP4 expert weights +""" + +import argparse +import json +import os +import shutil +from typing import Any, Dict, List, Optional, TextIO, Tuple + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from tqdm import tqdm +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + + +def deduce_experts_for_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, +) -> Tuple[List[int], int, int]: + """ + Deduce which original experts match the student experts by comparing weights. + + Compares dequantized MXFP4 weights from the original model against the student + model's BF16 weights using L2 distance. Finds the best 1-to-1 matching. + + Args: + layer: Layer index + original_path: Path to original model + original_index: Original model's safetensors index + student_path: Path to student model + num_student_experts: Number of experts in student model (if None, auto-detect) + + Returns: + Tuple of (expert_indices, num_student_experts, num_original_experts) + """ + # Load original tensors + orig_tensors = load_layer_tensors(original_path, layer, original_index) + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + + num_original_experts = mlp1_blocks.shape[0] + + # Load student tensors + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + if not os.path.exists(student_ffn): + print(f"FFN file not found at {student_ffn} - fallback to no_op") + return [], 0, num_original_experts + + student_experts = {} + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + if "experts" in key or "router" in key: + student_experts[key] = f.get_tensor(key) + + # Auto-detect number of student experts + num_student_experts = student_experts[f"model.layers.{layer}.mlp.experts.gate_up_proj"].size(0) + print( + f" Layer {layer}: Comparing {num_student_experts} student experts against {num_original_experts} original experts" + ) + + # Pre-dequantize all original experts once (optimization) + print(f" Pre-dequantizing {num_original_experts} original experts...") + deqexpert_mlp1 = convert_moe_packed_tensors(mlp1_blocks, mlp1_scales).cpu() + deqexpert_mlp2 = convert_moe_packed_tensors(mlp2_blocks, mlp2_scales).cpu() + original_experts_dequant = [] + for orig_idx in range(num_original_experts): + original_experts_dequant.append( + {"up": deqexpert_mlp1[orig_idx], "down": deqexpert_mlp2[orig_idx]} + ) + + # For each student expert, find best matching original expert + experts_to_keep = [] + used_original_indices = set() + + # Number of values to use for quick comparison (tune this) + quick_compare_size = 8 + # Number of candidates to keep for full comparison + top_k_candidates = min(10, num_original_experts) + + for student_idx in range(num_student_experts): + # Get student expert weights + prefix = f"model.layers.{layer}.mlp" + student_up = student_experts.get(f"{prefix}.experts.gate_up_proj")[student_idx] # type: ignore[index] + student_down = student_experts.get(f"{prefix}.experts.down_proj")[student_idx] # type: ignore[index] + + # if student_gate is None or student_up is None or student_down is None: + if student_up is None or student_down is None: + raise ValueError( + f"Missing student expert weights for layer {layer} expert {student_idx}" + ) + + # Step 1: Quick filtering using first N values + candidate_scores = [] + for orig_idx in range(num_original_experts): + if orig_idx in used_original_indices: + continue + + orig_expert = original_experts_dequant[orig_idx] + + up_quick = ( + ( + orig_expert["up"].flatten()[:quick_compare_size] + - student_up.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + down_quick = ( + ( + orig_expert["down"].flatten()[:quick_compare_size] + - student_down.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + + quick_score = (up_quick + down_quick) / 2.0 + candidate_scores.append((orig_idx, quick_score.item())) + + # Step 2: Get top-k candidates based on quick comparison + candidate_scores.sort(key=lambda x: x[1]) + top_candidates = [idx for idx, _ in candidate_scores[:top_k_candidates]] + + # Step 3: Full comparison only on top candidates + best_match_idx = None + best_match_score = float("inf") + + for orig_idx in top_candidates: + orig_expert = original_experts_dequant[orig_idx] + + # Full comparison across all values + up_diff = (orig_expert["up"] - student_up.float()).pow(2).mean().sqrt() + down_diff = (orig_expert["down"] - student_down.float()).pow(2).mean().sqrt() + + score = (up_diff + down_diff) / 2.0 + + if score < best_match_score: + best_match_score = score + best_match_idx = orig_idx + + if best_match_idx is None: + raise ValueError( + f"Could not find match for student expert {student_idx} in layer {layer}" + ) + + experts_to_keep.append(best_match_idx) + used_original_indices.add(best_match_idx) + print( + f" Student expert {student_idx} -> Original expert {best_match_idx} (RMSE: {best_match_score:.6f})" + ) + + return experts_to_keep, num_student_experts, num_original_experts + + +def load_original_index(path: str) -> Dict[str, Any]: + """Load the original model's safetensors index.""" + with open(path, "r") as f: + return json.load(f) + + +def load_layer_tensors(original_path: str, layer: int, index: Dict) -> Dict[str, torch.Tensor]: + """Load all MoE-related tensors for a layer, potentially from multiple files.""" + keys_to_load = [ + f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks", + f"model.layers.{layer}.mlp.experts.gate_up_proj_scales", + f"model.layers.{layer}.mlp.experts.gate_up_proj_bias", + f"model.layers.{layer}.mlp.experts.down_proj_blocks", + f"model.layers.{layer}.mlp.experts.down_proj_scales", + f"model.layers.{layer}.mlp.experts.down_proj_bias", + f"model.layers.{layer}.mlp.router.weight", # Router weight + f"model.layers.{layer}.mlp.router.bias", # Router bias + ] + + # Group by file + file_to_keys = {} + for key in keys_to_load: + if key in index["weight_map"]: + filename = index["weight_map"][key] + if filename not in file_to_keys: + file_to_keys[filename] = [] + file_to_keys[filename].append(key) + + # Load from each file + tensors = {} + for filename, keys in file_to_keys.items(): + filepath = os.path.join(original_path, filename) + with safe_open(filepath, framework="pt") as f: + for key in keys: + tensors[key] = f.get_tensor(key) + + return tensors + + +def copy_non_moe_weights(student_path: str, output_path: str, num_layers: int) -> Dict[str, str]: + """ + Copy non-MoE weights from student model. + Returns weight_map for the new index. + """ + weight_map = {} + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + os.makedirs(subblocks_dir, exist_ok=True) + + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Copy embeddings + src_emb = os.path.join(student_subblocks, "embeddings.safetensors") + dst_emb = os.path.join(subblocks_dir, "embeddings.safetensors") + shutil.copy2(src_emb, dst_emb) + with safe_open(src_emb, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/embeddings.safetensors" + + # Copy lm_head + src_head = os.path.join(student_subblocks, "lm_head.safetensors") + dst_head = os.path.join(subblocks_dir, "lm_head.safetensors") + shutil.copy2(src_head, dst_head) + with safe_open(src_head, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/lm_head.safetensors" + + # Copy attention blocks + for layer in range(num_layers): + src_attn = os.path.join(student_subblocks, f"block_{layer}_attention.safetensors") + dst_attn = os.path.join(subblocks_dir, f"block_{layer}_attention.safetensors") + shutil.copy2(src_attn, dst_attn) + with safe_open(src_attn, framework="pt") as f: + for key in f.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_attention.safetensors" + + return weight_map + + +def process_single_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, + output_path: str, + experts_to_keep: List[int], +) -> Tuple[Dict[str, str], List[str]]: + """ + Process a single layer - loads tensors from potentially multiple files. + Returns (weight_map, verification_errors). + """ + weight_map = {} + verification_errors = [] + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Load all tensors for this layer (may come from multiple files) + orig_tensors = load_layer_tensors(original_path, layer, original_index) + + # Load student FFN file + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + + tensors_to_save = {} + student_tensors = {} + + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if "experts" not in key and "router" not in key: + # Copy norm weights + tensors_to_save[key] = tensor + + # Get router from original model, sliced to kept experts + orig_router_weight = orig_tensors[f"model.layers.{layer}.mlp.router.weight"] + orig_router_bias = orig_tensors[f"model.layers.{layer}.mlp.router.bias"] + + kept_indices_tensor = torch.tensor(experts_to_keep, dtype=torch.long) + sliced_router_weight = orig_router_weight[kept_indices_tensor] + sliced_router_bias = orig_router_bias[kept_indices_tensor] + + tensors_to_save[f"model.layers.{layer}.mlp.router.weight"] = sliced_router_weight + tensors_to_save[f"model.layers.{layer}.mlp.router.bias"] = sliced_router_bias + + # Get MoE tensors + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + mlp1_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] + mlp2_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_bias"] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] = mlp1_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] = mlp1_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] = mlp1_bias[ + kept_indices_tensor + ] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] = mlp2_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_scales"] = mlp2_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_bias"] = mlp2_bias[ + kept_indices_tensor + ] + + # Save the FFN file + output_file = os.path.join(subblocks_dir, f"block_{layer}_ffn.safetensors") + save_file(tensors_to_save, output_file) + + # Build weight map + for key in tensors_to_save.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_ffn.safetensors" + + return weight_map, verification_errors + + +def copy_config_files(student_path: str, output_path: str): + """Copy configuration files from student model and update config.json.""" + files_to_copy = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "chat_template.jinja", + ] + + # Also copy transformers compatibility files + if os.path.exists(student_path): + for f in os.listdir(student_path): + if f.startswith("transformers_"): + files_to_copy.append(f) + + for filename in files_to_copy: + src = os.path.join(student_path, filename) + dst = os.path.join(output_path, filename) + + # Try student path first + if os.path.exists(src): + try: + shutil.copy2(src, dst) + continue + except PermissionError: + pass + + # If we get here, file doesn't exist or permission denied + if not os.path.exists(dst): + print(f" Warning: Could not copy {filename}") + + # Update config.json for DeciGptOssForCausalLM with MXFP4 + src_config = os.path.join(student_path, "config.json") + if not os.path.exists(src_config): + raise FileNotFoundError(f"config.json not found at {src_config}") + + with open(src_config, "r") as f: + config = json.load(f) # type: ignore[arg-type] + + # Set architecture to DeciGptOssForCausalLM for MXFP4 support + config["architectures"] = ["DeciGptOssForCausalLM"] + + # Add quantization_config so vllm calls _load_weights_mxfp4 + config["quantization_config"] = { + "quant_method": "mxfp4", + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ], + } + + dst_config = os.path.join(output_path, "config.json") + with open(dst_config, "w") as f: + json.dump(config, f, indent=2) # type: ignore[arg-type] + + +def main(): + parser = argparse.ArgumentParser(description="Create MXFP4 checkpoint from student model") + parser.add_argument( + "--student-path", type=str, required=True, help="Path to student model checkpoint" + ) + parser.add_argument( + "--original-path", + type=str, + required=True, + help="Path to original gpt-oss-120b model with MXFP4 weights", + ) + parser.add_argument( + "--output-path", type=str, required=True, help="Output path for the new checkpoint" + ) + parser.add_argument("--num-layers", type=int, default=36, help="Number of transformer layers") + args = parser.parse_args() + + print(f"Creating MXFP4 checkpoint...") + print(f" Student model: {args.student_path}") + print(f" Original model: {args.original_path}") + print(f" Output: {args.output_path}") + + # Load original model index + original_index = load_original_index( + os.path.join(args.original_path, "model.safetensors.index.json") + ) + + print("\nDeducing expert mappings by comparing weights...") + experts_to_keep = [] + layer_statistics = [] # Store (num_student, num_original) for each layer + + for layer in range(args.num_layers): + layer_experts, num_student, num_original = deduce_experts_for_layer( + layer, + args.original_path, + original_index, + args.student_path, + ) + experts_to_keep.append(layer_experts) + layer_statistics.append((num_student, num_original)) + + # Print statistics + print(f"\n{'=' * 70}") + print("EXPERT DEDUCTION STATISTICS") + print(f"{'=' * 70}") + print(f"{'Layer':<8} {'Student Experts':<18} {'Original Experts':<18} {'Kept %':<10}") + print(f"{'-' * 70}") + + total_student = 0 + total_original = 0 + for layer, (num_student, num_original) in enumerate(layer_statistics): + percentage = (num_student / num_original * 100) if num_original > 0 else 0 + print(f"{layer:<8} {num_student:<18} {num_original:<18} {percentage:<10.2f}") + total_student += num_student + total_original += num_original + + print(f"{'-' * 70}") + avg_percentage = (total_student / total_original * 100) if total_original > 0 else 0 + print(f"{'TOTAL':<8} {total_student:<18} {total_original:<18} {avg_percentage:<10.2f}") + print(f"{'=' * 70}") + print(f"\n Deduced experts_to_keep mapping for {len(experts_to_keep)} layers") + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + os.makedirs(os.path.join(args.output_path, "subblocks_safetensors"), exist_ok=True) + + # Copy config files + print("Copying configuration files...") + copy_config_files(args.student_path, args.output_path) + + # Save experts_to_keep.json + experts_to_keep_output = os.path.join(args.output_path, "experts_to_keep.json") + with open(experts_to_keep_output, "w") as f: + json.dump(experts_to_keep, f, indent=2) + print(f" Saved experts_to_keep mapping to {experts_to_keep_output}") + + # Copy non-MoE weights (embeddings, attention, lm_head) + print("Copying non-MoE weights...") + weight_map = copy_non_moe_weights(args.student_path, args.output_path, args.num_layers) + + # Load weights per layer (handles multi-file loading) + print(f"Processing {args.num_layers} layers...") + + all_verification_errors = [] + + # Process each layer + for layer in tqdm(range(args.num_layers), desc="Processing layers"): + if len(experts_to_keep[layer]) == 0: + print(f"Layer {layer} has no experts to keep - ffn->no_op") + continue + layer_weight_map, layer_errors = process_single_layer( + layer, + args.original_path, + original_index, + args.student_path, + args.output_path, + experts_to_keep[layer], + ) + weight_map.update(layer_weight_map) + all_verification_errors.extend(layer_errors) + + # Calculate total size + total_size = 0 + subblocks_dir = os.path.join(args.output_path, "subblocks_safetensors") + for filename in os.listdir(subblocks_dir): + filepath = os.path.join(subblocks_dir, filename) + total_size += os.path.getsize(filepath) + + # Create model.safetensors.index.json + index = {"metadata": {"total_size": total_size}, "weight_map": weight_map} + + index_path = os.path.join(args.output_path, "model.safetensors.index.json") + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + + print(f"\nCheckpoint created successfully at: {args.output_path}") + print(f"Total size: {total_size / 1e9:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index 7615c5d085..8b7711c3cb 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -146,6 +146,10 @@ def create_and_save_small_hf_model( if hasattr(config, "hybrid_override_pattern") and hybrid_override_pattern is not None: config.hybrid_override_pattern = hybrid_override_pattern + # Ensure pad_token_id is within vocab_size (nn.Embedding requires padding_idx < num_embeddings) + if getattr(config, "pad_token_id", None) is not None and config.pad_token_id >= vocab_size: + config.pad_token_id = 0 + # Set seed for reproducible weight initialization torch.manual_seed(42) diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml new file mode 100644 index 0000000000..2b77516174 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml @@ -0,0 +1,109 @@ +# @package _global_ +defaults: + - /openai/gpt-oss-20b/pruning@pruning: expert_removal # TODO: Note: Works for unquantized test models, not MXFP4 quantized production models + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - bypass: + - override /hydra/hydra_logging: disabled + - _self_ + +descriptor: gpt_oss + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true # TODO: Works for unquantized test models + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + - stats.num_local_experts: 48 # teacher has: 2 layers * 32 experts = 64 total experts + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml new file mode 100644 index 0000000000..4656f1df42 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml @@ -0,0 +1,19 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor + target_name: "mlp.router" +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.RankedChoiceVotingHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks" + layer_prefix_template: "model.layers.{layer_idx}.mlp.router" diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index cf600558e5..fa9e5281dc 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -45,7 +45,7 @@ ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small", None, False), ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", "nemotron_h", "*E", True), ("nvidia/NVIDIA-Nemotron-Nano-12B-v2", "nemotron_h_v2", "*-", False), - # ("openai/gpt-oss-20b", "gpt_oss", None, True), + ("openai/gpt-oss-20b", "gpt_oss", None, True), ("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False), ("Qwen/Qwen3-8B", "qwen3", None, False), ("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True), @@ -86,7 +86,6 @@ def _test_puzzletron_multiprocess_job( ): # Set seed BEFORE dist.setup() to ensure reproducibility across all processes set_seed(SEED) - dist.setup(timeout=timedelta(10)) # Setup the test model and data. From 660dc17ef3ca40db5a691d9fac4dc0b56cb3f196 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 19 Mar 2026 13:56:03 +0100 Subject: [PATCH 43/62] Merge any_model tutorial (#1035) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Merge any_model tutorial for Puzzletron. ## Summary by CodeRabbit * **New Features** * MIP sweep mode for multi-rate memory-compression searches * Triton-ready HuggingFace deployable, lm-eval adapter, and Ray-compatible inference pathways * Megatron-Bridge distillation CLI/workflow with optional HF export * **New Configurations** * Extensive pruning/memory-sweep profiles for GPT‑Oss, Llama, Mistral, Nemotron, Qwen families * **Documentation** * GptOss guide and conversion example, expanded READMEs, MIP Quick Start, NeMo evaluator notes * **Chores** * Requirements updated (lm-eval bumped; math-verify, ray added) --------- Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 1 + examples/puzzletron/GPTOSS.md | 14 + examples/puzzletron/README.md | 79 +- .../gptoss-20b.yaml | 110 +++ .../gptoss-20b_remove_experts_memory.yaml | 17 + .../pruning/ffn_pruning.yaml | 21 + .../pruning/pruning_defaults.yaml | 34 + .../validate_model_defaults.yaml | 18 + .../validate_solutions_defaults.yaml | 11 + .../Llama-3_1-8B.yaml | 3 + .../llama-3_1-8B_pruneffn_memory.yaml | 5 + .../pruning/ffn_pruning.yaml | 7 + .../pruning/pruning_defaults.yaml | 3 +- .../Llama-3_2-3B.yaml | 110 +++ .../llama-3_2-3B_pruneffn_memory.yaml | 22 + .../pruning/ffn_pruning.yaml | 21 + .../pruning/pruning_defaults.yaml | 33 + .../validate_model_defaults.yaml | 18 + .../validate_solutions_defaults.yaml | 11 + .../Mistral-Small-24B.yaml | 109 +++ ...all-24b-instruct-2501_pruneffn_memory.yaml | 21 + .../pruning/attn_pruning.yaml | 17 + .../pruning/ffn_pruning.yaml | 20 + .../pruning/hidden_dim_pruning.yaml | 16 + .../pruning/pruning_defaults.yaml | 33 + .../validate_model_defaults.yaml | 17 + .../validate_solutions_defaults.yaml | 10 + .../nemotron_nano_12b_v2.yaml | 109 +++ .../nemotron_nano_12b_v2_pruneffn_memory.yaml | 22 + .../pruning/attn_pruning.yaml | 16 + .../pruning/ffn_pruning.yaml | 18 + .../pruning/hidden_dim_pruning.yaml | 15 + .../pruning/pruning_defaults.yaml | 34 + .../validate_model_defaults.yaml | 17 + .../validate_solutions_defaults.yaml | 10 + .../pruning/attn_pruning.yaml | 16 + .../pruning/ffn_pruning.yaml | 18 + .../pruning/hidden_dim_pruning.yaml | 15 + .../pruning/pruning_defaults.yaml | 34 + .../qwen2_5_7b_instruct.yaml | 109 +++ .../qwen2_5_7b_instruct_pruneffn_memory.yaml | 22 + .../validate_model_defaults.yaml | 17 + .../validate_solutions_defaults.yaml | 10 + .../pruning/attn_pruning.yaml | 16 + .../pruning/ffn_pruning.yaml | 18 + .../pruning/hidden_dim_pruning.yaml | 15 + .../pruning/pruning_defaults.yaml | 34 + .../qwen3-8b_pruneffn_memory/qwen3_8b.yaml | 109 +++ .../qwen3_8b_pruneffn_memory.yaml | 22 + .../validate_model_defaults.yaml | 17 + .../validate_solutions_defaults.yaml | 10 + .../evaluation/hf_deployable_anymodel.py | 724 ++++++++++++++++++ .../puzzletron/evaluation/lm_eval_anymodel.py | 115 +++ .../evaluation/nemo_evaluator_instructions.md | 70 ++ examples/puzzletron/main.py | 22 +- .../puzzletron/mbridge_distillation/README.md | 152 ++++ .../mbridge_distillation/distill_hf.py | 326 ++++++++ examples/puzzletron/mip_sweep_example.png | Bin 0 -> 53715 bytes examples/puzzletron/requirements.txt | 4 +- 59 files changed, 2885 insertions(+), 32 deletions(-) create mode 100644 examples/puzzletron/GPTOSS.md create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml create mode 100644 examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml create mode 100644 examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml create mode 100644 examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml create mode 100644 examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml create mode 100644 examples/puzzletron/evaluation/hf_deployable_anymodel.py create mode 100644 examples/puzzletron/evaluation/lm_eval_anymodel.py create mode 100644 examples/puzzletron/evaluation/nemo_evaluator_instructions.md create mode 100644 examples/puzzletron/mbridge_distillation/README.md create mode 100644 examples/puzzletron/mbridge_distillation/distill_hf.py create mode 100644 examples/puzzletron/mip_sweep_example.png diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 966aaedd55..546423fa77 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -108,6 +108,7 @@ repos: examples/llm_eval/modeling.py| examples/llm_qat/main.py| examples/llm_sparsity/weight_sparsity/finetune.py| + examples/puzzletron/evaluation/lm_eval_anymodel.py| examples/specdec_bench/specdec_bench/models/specbench_medusa.py| examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| diff --git a/examples/puzzletron/GPTOSS.md b/examples/puzzletron/GPTOSS.md new file mode 100644 index 0000000000..7c160c8997 --- /dev/null +++ b/examples/puzzletron/GPTOSS.md @@ -0,0 +1,14 @@ + +## GptOss + +With this release Puzzle algorithm supports only experts removal for `Gpt-Oss`. + +This model comes as a quantized checkpoint i.e. MoE experts matrices are quantized with _MXFP4_ format. +In the pruning steps puzzle utilizes decompressed model (back to BF16) for statistics and scores computation. +This means, during the conversion to puzzle format we decompress the model and store it as a BF16. +Once the pruning is done i.e. experts to be removed are identified and the process is finished, user may want to get back the _MXFP4_ format of the checkpoint. +To do so, there is an additional script, that takes the original and the pruned checkpoint and outputs pruned checkpoint in _MXFP4_ format. + +```bash +python -m modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_pruned_to_mxfp4 --student-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/solution_0/ --original-path /workspaces/source_model_checkpoints/openai_gpt-oss-20b/ --output-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/mxfp4-ckpt/ --num-layers 24 +``` diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index f16162083d..a7e3aedfc1 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -9,18 +9,23 @@ The supported modifications are: To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on Mixed-Integer Programming (MIP) algorithm to find the most optimal combination of layer modifications that satisfy the target requirements. -In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. +In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md). + +> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). ## Environment -- Install Model-Optimizer in editable mode with the corresponding dependencies: +- Install Model-Optimizer in editable mode with the corresponding dependencies (run from the repo root): ```bash pip install -e .[hf,puzzletron] -pip install -r requirements.txt +pip install -r examples/puzzletron/requirements.txt ``` -- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU. +> **Note:** NeMo containers may ship `nvidia-lm-eval` which may conflict with `lm-eval` that is used for evaluation. +> If so, run `pip uninstall nvidia-lm-eval -y` before installing requirements. + +- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use a single GPU. - To make use of [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), you need to accept the terms and conditions for the corresponding model and the dataset in the Huggingface Hub. Log in to the Huggingface Hub and enter your HF token. @@ -133,7 +138,7 @@ This assumes pruning, replacement library building, NAS scoring, and subblock st For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`. ```bash -torchrun --nproc_per_node 2 examples/puzzletron/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" +torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" ``` This will generate the following network architecture (see `log.txt`): @@ -195,18 +200,54 @@ block_13: attention no_op ffn intermediate_11520 block_14: attention no_op ffn intermediate_3072 ``` +### MIP Sweep Mode + +The **MIP sweep mode** lets you explore multiple memory compression rates in a single run and compare the accuracy-memory trade-offs. + +#### Quick Start + +1. Enable sweep in your config YAML (e.g., `llama-3_1-8B_pruneffn_memory.yaml`): + + ```yaml + mip: + sweep: + enabled: true + memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + output_csv: ${puzzle_dir}/mip_sweep_results.csv + ``` + +2. Run the sweep: + + ```bash + torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" + ``` + +3. View results: The CSV file contains compression rates, memory usage, and accuracy metrics for each configuration. + +#### Example Results + +MIP Sweep Results + +The plot shows how token accuracy changes with different compression rates. Higher compression (0.5 = 50% of original memory) reduces accuracy, while lower compression maintains accuracy closer to the teacher model. + ## Evaluation -Once the model is ready, you can evaluate it using [Language Model Evaluation Harness](https://pypi.org/project/lm-eval/). For example, run the following to evaluate the model on [Massive Multitask Language Understanding](https://huggingface.co/datasets/cais/mmlu) benchmark. +Evaluate AnyModel checkpoints using [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) directly. ```bash -lm_eval --model hf \ - --model_args pretrained=path/to/model,dtype=bfloat16,trust_remote_code=true,parallelize=True \ - --tasks mmlu \ - --num_fewshot 5 \ - --batch_size 4 +python examples/puzzletron/evaluation/lm_eval_anymodel.py \ + --model hf \ + --model_args pretrained=path/to/checkpoint,dtype=bfloat16,parallelize=True \ + --tasks mmlu \ + --num_fewshot 5 \ + --batch_size 4 ``` +For a quick smoke test, add `--limit 10`. + +> **Alternative:** For server-based evaluation via an OpenAI-compatible endpoint, +> see [evaluation/nemo_evaluator_instructions.md](./evaluation/nemo_evaluator_instructions.md). + ## Inference Performance Benchmarking Now let's evaluate how much speedup we get with the compressed model in terms of throughput and latency. @@ -234,21 +275,9 @@ vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 -- ## Knowledge Distillation -To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. For this, we will use [NeMo framework](https://github.com/NVIDIA-NeMo/NeMo) with the [nemo:25.07](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.07) container. - -First, convert the HF model to NeMo format: +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. -```bash -python -m nemo_export/convert_hf_to_nemo --input-ckpt-path path/to/HF-model --output-ckpt-path path/to/save/model-nemo -``` - -Now you can utilize all the training features available in NeMo, including distillation. Please refer to the [NeMo distillation documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html). - -[Optional] Once distillation is complete, you can convert the distilled model back to the HuggingFace format. - -```bash -python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --output-ckpt-path path/to/save/model-HF -``` +See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation. ## Advanced Usage diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml new file mode 100644 index 0000000000..b48f1de78c --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: gpt_oss +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml new file mode 100644 index 0000000000..8ed06e9568 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml @@ -0,0 +1,17 @@ +defaults: + - gptoss-20b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/openai/gpt-oss-20b + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 16_000 # 45 GiB diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..8b19e167d0 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,21 @@ +defaults: + - pruning_defaults + +eval_samples: 2500 #10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor + target_name: "mlp.router" + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.RankedChoiceVotingHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks" + layer_prefix_template: "model.layers.{layer_idx}.mlp.router" + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..0eff799d7e --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 10_000 +micro_batch_size: 1 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..b80faea5f5 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ab8c892182 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml index 7045e0d002..21903db162 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -7,6 +7,7 @@ defaults: - _self_ puzzle_dir: ??? +descriptor: llama teacher_dir: ${puzzle_dir}/ckpts/teacher/ replacement_library_path: ${puzzle_dir}/replacement_library.json dataset_path: ??? # ppath to Nemotron-Post-Training-Dataset-v2 @@ -32,6 +33,7 @@ calc_subblock_stats: backend: trt_torch scoring: + descriptor: ${descriptor} solutions_to_validate: skip_existing_solutions: true @@ -84,6 +86,7 @@ mip: max_seconds_per_solution: 60 realize_model: + descriptor: ${descriptor} teacher_dir: ${to_path:${teacher_dir}} tokenizer_name: ${to_path:${teacher_dir}} replacement_library_path: ${replacement_library_path} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml index 20eec970e7..ad16dbc5ea 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml @@ -15,6 +15,11 @@ puzzle_dir: /workspace/puzzle_dir mip: human_constraints: target_memory: 78_000 # 78 GiB + # Memory sweep configuration (optional) + sweep: + enabled: false + memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9] + output_csv: ${puzzle_dir}/mip_sweep_results.csv # FFN intermediate sizes to search over (heterogeneous architecture) pruning: diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml index 96a8ca72e4..aa857c5ace 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -1,6 +1,13 @@ defaults: - pruning_defaults +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} + activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} activation_hooks_kwargs: diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml index 5d5307b9c7..e05e775bee 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -1,6 +1,7 @@ defaults: - /validate_model_defaults +descriptor: ${descriptor} model_name_or_path: ${teacher_dir} experiment_id: ${pruning.eval_samples}samples_diverse_mini activations_log_dir: ??? @@ -13,7 +14,7 @@ dataset_path: ${dataset_path} val_dataset_name: train # Prune ckpts -pruned_ckpts_outpt_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} ## FFN pruning ffn_list: diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml new file mode 100644 index 0000000000..7de281e788 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: llama +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml new file mode 100644 index 0000000000..b5303d318a --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - Llama-3_2-3B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.2-3B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 45_000 # 45 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 8192, so we use proportionally smaller values +pruning: + intermediate_size_list: [2048, 4096, 6144] diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..a58c42c521 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,21 @@ +defaults: + - pruning_defaults + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +# Llama-3.2-3B has intermediate_size=8192, so we use proportionally smaller pruning sizes +intermediate_size_list: [2048, 4096, 6144] +mlp_init_mode: "PruneByActivationsLog" + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..b80faea5f5 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ab8c892182 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml new file mode 100644 index 0000000000..18213f9b7a --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: mistral_small +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 24_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml new file mode 100644 index 0000000000..68a0652d6f --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml @@ -0,0 +1,21 @@ +defaults: + - Mistral-Small-24B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/mistralai/Mistral-Small-24B-Instruct-2501 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 234_000 # 234 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [8192, 16384, 24576] # teacher_intermediate_size is 32768 diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..cb24e1bc24 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,17 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# Mistral Small 24B: 32 query heads, 8 KV heads +# n_heads_in_group = num_query_heads / num_kv_heads +# num_kv_heads = num_query_heads / n_heads_in_group +# Base: n_heads_in_group = 4, num_kv_heads = 8 +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..0982d90aa8 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,20 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 32768 +intermediate_size_list: [8192, 16384, 24576] +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..7de32621e0 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +# Mistral Small 24B: hidden_size is 5120 +hidden_size_list: [3072, 4096] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml new file mode 100644 index 0000000000..62b6ecb4cb --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: nemotron_h_v2 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 90_000 + num_params: 12_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml new file mode 100644 index 0000000000..3b880b2c7d --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - nemotron_nano_12b_v2 + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/nvidia/Nemotron-Nano-12B-v2 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 90_000 # 90 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 20480 +pruning: + intermediate_size_list: [4352, 8448, 12544, 16384] diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..60e421b239 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mixer.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..3f7a248ee7 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..6a5922959d --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..af8af990b7 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml new file mode 100644 index 0000000000..aa11499a3c --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: qwen2 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 7_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml new file mode 100644 index 0000000000..fb961033bc --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - qwen2_5_7b_instruct + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/Qwen/Qwen2.5-7B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 18944 +pruning: + intermediate_size_list: [4096, 7808, 11520, 15104] diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..0b6fa59fbf --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor.Qwen3_8BFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml new file mode 100644 index 0000000000..eec82a7d63 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: qwen3 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 8_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml new file mode 100644 index 0000000000..4ee81286dd --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - qwen3_8b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/Qwen/Qwen3-8B + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 12288 +pruning: + intermediate_size_list: [2560, 5120, 7424, 9984] diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/evaluation/hf_deployable_anymodel.py b/examples/puzzletron/evaluation/hf_deployable_anymodel.py new file mode 100644 index 0000000000..3ca8dd7581 --- /dev/null +++ b/examples/puzzletron/evaluation/hf_deployable_anymodel.py @@ -0,0 +1,724 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import json +import logging +from typing import Any + +import numpy as np +import torch +from nemo_deploy import ITritonDeployable +from nemo_deploy.utils import broadcast_list, cast_output, str_ndarray2list +from nemo_export_deploy_common.import_utils import ( + MISSING_TRITON_MSG, + UnavailableError, + null_decorator, +) +from peft import PeftModel +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import ( + resolve_descriptor_from_pretrained, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor + + HAVE_TRITON = True +except (ImportError, ModuleNotFoundError): + from unittest.mock import MagicMock + + HAVE_TRITON = False + batch = MagicMock() + Tensor = MagicMock() + batch = null_decorator + + +LOGGER = logging.getLogger("NeMo") + +SUPPORTED_TASKS = ["text-generation"] + + +class HuggingFaceLLMDeploy(ITritonDeployable): + """A Triton inference server compatible wrapper for HuggingFace models. + + This class provides a standardized interface for deploying HuggingFace models + in Triton inference server. It supports various NLP tasks and handles model + loading, inference, and deployment configurations. + + Args: + hf_model_id_path (Optional[str]): Path to the HuggingFace model or model identifier. + Can be a local path or a model ID from HuggingFace Hub. + hf_peft_model_id_path (Optional[str]): Path to the PEFT model or model identifier. + Can be a local path or a model ID from HuggingFace Hub. + tokenizer_id_path (Optional[str]): Path to the tokenizer or tokenizer identifier. + If None, will use the same path as hf_model_id_path. + model (Optional[AutoModel]): Pre-loaded HuggingFace model. + tokenizer (Optional[AutoTokenizer]): Pre-loaded HuggingFace tokenizer. + tokenizer_padding (bool): Whether to enable padding in tokenizer. Defaults to True. + tokenizer_truncation (bool): Whether to enable truncation in tokenizer. Defaults to True. + tokenizer_padding_side (str): Which side to pad on ('left' or 'right'). Defaults to 'left'. + task (str): HuggingFace task type (e.g., "text-generation"). Defaults to "text-generation". + **hf_kwargs: Additional keyword arguments to pass to HuggingFace model loading. + """ + + def __init__( + self, + hf_model_id_path: str | None = None, + hf_peft_model_id_path: str | None = None, + tokenizer_id_path: str | None = None, + model: AutoModel | None = None, + tokenizer: AutoTokenizer | None = None, + tokenizer_padding=True, + tokenizer_truncation=True, + tokenizer_padding_side="left", + task: str | None = "text-generation", + torch_dtype: torch.dtype | None = "auto", + device_map: str | None = "auto", + **hf_kwargs, + ): + if not HAVE_TRITON: + raise UnavailableError(MISSING_TRITON_MSG) + + if hf_model_id_path is None and model is None: + raise ValueError("hf_model_id_path or model parameters has to be passed.") + elif hf_model_id_path is not None and model is not None: + LOGGER.warning( + "hf_model_id_path will be ignored and the HuggingFace model set with model parameter will be used." + ) + + assert task in SUPPORTED_TASKS, "Task {} is not a support task.".format(task) + + self.hf_model_id_path = hf_model_id_path + self.hf_peft_model_id_path = hf_peft_model_id_path + self.task = task + self.model = model + self.tokenizer = tokenizer + self.tokenizer_padding = tokenizer_padding + self.tokenizer_truncation = tokenizer_truncation + self.tokenizer_padding_side = tokenizer_padding_side + + if tokenizer_id_path is None: + self.tokenizer_id_path = hf_model_id_path + else: + self.tokenizer_id_path = tokenizer_id_path + + if model is None: + self._load(torch_dtype=torch_dtype, device_map=device_map, **hf_kwargs) + + def _load( + self, torch_dtype: torch.dtype | None = "auto", device_map: str | None = "auto", **hf_kwargs + ) -> None: + """Load the HuggingFace pipeline with the specified model and task. + + This method initializes the HuggingFace AutoModel classes using the provided model + configuration and task type. It handles the model and tokenizer loading + process. + + Args: + torch_dtype (torch.dtype): Data type for the model. Defaults to "auto". + device_map (str): Device map for the model. Defaults to "auto". + **hf_kwargs: Additional keyword arguments to pass to the HuggingFace model loading. + + Raises: + AssertionError: If task is not specified. + """ + assert self.task is not None, "A task has to be given for the generation task." + + if self.task == "text-generation": + # ========================================================================= + # BEGIN ANYMODEL PATCH + # Wraps model loading with deci_x_patcher for heterogeneous layer configs. + # See: modelopt/torch/puzzletron/anymodel/puzzformer/utils.py + # ========================================================================= + + descriptor = resolve_descriptor_from_pretrained( + self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False) + ) + + with deci_x_patcher(model_descriptor=descriptor): + self.model = AutoModelForCausalLM.from_pretrained( + self.hf_model_id_path, + torch_dtype=torch_dtype, + device_map=device_map, + **hf_kwargs, + ) + # ========================================================================= + # END ANYMODEL PATCH + # ========================================================================= + + if self.hf_peft_model_id_path is not None: + self.model = PeftModel.from_pretrained(self.model, self.hf_peft_model_id_path) + else: + raise ValueError("Task {} is not supported.".format(self.task)) + num_gpus = torch.cuda.device_count() + # If there is only one GPU, move the model to GPU. If you are using device_map as "auto" or "balanced", + # the model will be moved to GPU automatically. + if device_map is None and num_gpus >= 1 and self.model.device.type != "cuda": + self.model.cuda() + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_id_path, + trust_remote_code=hf_kwargs.pop("trust_remote_code", False), + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + padding_side=self.tokenizer_padding_side, + ) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def generate( + self, + **kwargs: Any, + ) -> list[str]: + """Generate text based on the provided input prompts. + + This method processes input prompts through the loaded pipeline and + generates text according to the specified parameters. + + Args: + **kwargs: Generation parameters including: + - text_inputs: List of input prompts + - max_length: Maximum number of tokens to generate + - num_return_sequences: Number of sequences to generate per prompt + - temperature: Sampling temperature + - top_k: Number of highest probability tokens to consider + - top_p: Cumulative probability threshold for token sampling + - do_sample: Whether to use sampling, default is False for greedy decoding + - echo: Whether to return prompt + generated text (True) or just generated text (False) + - return_full_text: Whether to return full text or only generated part + + Returns: + If output logits and output scores are False: + List[str]: A list of generated texts, one for each input prompt. + If output logits and output scores are True: + Dict: A dictionary containing: + - sentences: List of generated texts + - logits: List of logits + - scores: List of scores + - input_lengths: List of input token lengths (for echo processing) + + Raises: + RuntimeError: If the pipeline is not initialized. + """ + if not self.model: + raise RuntimeError("Model is not initialized") + + inputs = self.tokenizer( + kwargs["text_inputs"], + return_tensors="pt", + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + ) + + # Store input lengths to extract only generated tokens later + input_lengths = [len(input_ids) for input_ids in inputs["input_ids"]] + + # Get echo parameter (default False - only return generated text) + echo = kwargs.pop("echo", False) + kwargs.pop("text_inputs") # Remove text_inputs as it's already been tokenized + + kwargs = {**inputs, **kwargs} + for key, val in kwargs.items(): + if torch.is_tensor(val): + kwargs[key] = val.cuda() + + with torch.no_grad(): + generated_ids = self.model.generate(**kwargs) + return_dict_in_generate = kwargs.get("return_dict_in_generate", False) + if return_dict_in_generate: + # Handle dict output (when logits/scores are requested) + sequences = generated_ids["sequences"] + output = {"sentences": [], "input_lengths": input_lengths, "sequences": sequences} + + if echo: + # Return full text (prompt + generated). + # HF model's generate returns the input/prompt tokens as well by default. + for i, seq in enumerate(sequences): + full_text = self.tokenizer.decode(seq, skip_special_tokens=True) + output["sentences"].append(full_text) + else: + # Extract only the generated tokens (skip input tokens). + # This is required as HF model's generate returns the input/prompt tokens + # as well by default. (return_full_text is specific to some models) + for i, seq in enumerate(sequences): + input_len = input_lengths[i] if i < len(input_lengths) else 0 + generated_tokens = seq[input_len:] # Skip input tokens + generated_text = self.tokenizer.decode( + generated_tokens, skip_special_tokens=True + ) + output["sentences"].append(generated_text) + + if kwargs.get("output_logits", False): + output["logits"] = generated_ids["logits"] + if kwargs.get("output_scores", False): + output["scores"] = generated_ids["scores"] + else: + # Handle list output (normal case) + output = [] + if echo: + # Return full text (prompt + generated), which is the default in case of HF model generate. + for i, seq in enumerate(generated_ids): + full_text = self.tokenizer.decode(seq, skip_special_tokens=True) + output.append(full_text) + else: + # Extract only the generated tokens (skip input tokens) as the default + # behavior returns the input/prompt tokens as well. + for i, seq in enumerate(generated_ids): + input_len = input_lengths[i] if i < len(input_lengths) else 0 + generated_tokens = seq[input_len:] # Skip input tokens + generated_text = self.tokenizer.decode( + generated_tokens, skip_special_tokens=True + ) + output.append(generated_text) + + return output + + def generate_other_ranks(self): + """Generate function for ranks other than the rank 0.""" + while True: + message = torch.empty(1, dtype=torch.long, device="cuda") + torch.distributed.broadcast(message, src=0) + if message == 0: + prompts = broadcast_list(data=[None], src=0) + ( + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ) = broadcast_list(data=[None], src=0) + + return_dict_in_generate = False + if output_logits or output_scores: + return_dict_in_generate = True + + self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + ) + else: + return + + @property + def get_triton_input(self): + inputs = ( + Tensor(name="prompts", shape=(-1,), dtype=bytes), + Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_batch_size", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="random_seed", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="output_logits", shape=(-1,), dtype=np.bool_, optional=True), + Tensor(name="output_scores", shape=(-1,), dtype=np.bool_, optional=True), + ) + return inputs + + @property + def get_triton_output(self): + return ( + Tensor(name="sentences", shape=(-1,), dtype=bytes), + Tensor(name="logits", shape=(-1,), dtype=np.single), + Tensor(name="scores", shape=(-1,), dtype=np.single), + ) + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + output_infer = {} + + try: + prompts = str_ndarray2list(inputs.pop("prompts")) + temperature = inputs.pop("temperature")[0][0] if "temperature" in inputs else 1.0 + top_k = int(inputs.pop("top_k")[0][0] if "top_k" in inputs else 1) + top_p = inputs.pop("top_p")[0][0] if "top_p" in inputs else 0 + num_tokens_to_generate = ( + inputs.pop("max_length")[0][0] if "max_length" in inputs else 256 + ) + output_logits = ( + inputs.pop("output_logits")[0][0] if "output_logits" in inputs else False + ) + output_scores = ( + inputs.pop("output_scores")[0][0] if "output_scores" in inputs else False + ) + return_dict_in_generate = False + if output_logits or output_scores: + return_dict_in_generate = True + + if torch.distributed.is_initialized(): + if torch.distributed.get_world_size() > 1: + torch.distributed.broadcast( + torch.tensor([0], dtype=torch.long, device="cuda"), src=0 + ) + broadcast_list(prompts, src=0) + broadcast_list( + data=[ + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ], + src=0, + ) + + output = self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + echo=False, + ) + + if isinstance(output, dict): + output_infer = {"sentences": cast_output(output["sentences"], np.bytes_)} + + if "scores" in output: + output_scores = [] + for r in output["scores"]: + lp = torch.tensor(r).cpu().detach().numpy() + if len(lp) == 0: + output_scores.append([0]) + else: + output_scores.append(lp) + output_infer["scores"] = np.array(output_scores).transpose(1, 0, 2) + + if "logits" in output: + output_logits = [] + for r in output["logits"]: + lp = torch.tensor(r).cpu().detach().numpy() + if len(lp) == 0: + output_logits.append([0]) + else: + output_logits.append(lp) + output_infer["logits"] = np.array(output_logits).transpose(1, 0, 2) + else: + output_infer = {"sentences": cast_output(output, np.bytes_)} + + except Exception as error: + err_msg = "An error occurred: {}".format(str(error)) + output_infer["sentences"] = cast_output([err_msg], np.bytes_) + + return output_infer + + def _compute_logprobs( + self, + prompts: list[str], + output_infer: dict[str, Any], + compute_logprob: bool, + n_top_logprobs: int, + echo: bool, + ): + """Compute log probabilities and top log probabilities from model scores. + Used by ray_infer_fn to provide OAI API compatible output for evaluations. + + This method processes the raw scores from model generation to compute: + - Log probabilities for chosen tokens + - Top-k log probabilities for each position (if requested) + - Handles both prompt tokens (when echo=True) and generated tokens + + Args: + prompts: List of input prompts + output_infer: Dictionary containing model outputs including scores, sequences, and input_lengths + compute_logprob: Whether to compute log probabilities + n_top_logprobs: Number of top log probabilities to return (0 to disable) + echo: Whether to include prompt token log probabilities + + Returns: + Tuple[Optional[List], Optional[List]]: + - log_probs_list: List of log probabilities for each sample (None if not computed) + - top_logprobs_list: List of top-k log probabilities for each sample (None if not computed) + """ + # Tokenize the prompts to get prompt token IDs (needed for echo) + prompt_token_ids = None + prompt_inputs = None + if echo: + prompt_inputs = self.tokenizer( + prompts, + return_tensors="pt", + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + ) + prompt_token_ids = prompt_inputs["input_ids"] + # Move to same device as model + for key, val in prompt_inputs.items(): + if torch.is_tensor(val): + prompt_inputs[key] = val.cuda() + + # Process each sample + log_probs_list = [] + top_logprobs_list = [] + + for sample_idx in range(len(prompts)): + sample_log_probs = [] + sample_top_logprobs = [] + + # Get the generated sequence for this sample + sequences = output_infer["sequences"][sample_idx] + + # For echo, compute prompt token logprobs by running forward pass + if echo and prompt_token_ids is not None: + prompt_len = len(prompt_token_ids[sample_idx]) + + # Run forward pass on prompt to get logits for prompt tokens as scores in output_infer contains + # logits only for generated tokens. + with torch.no_grad(): + # Create input for this specific sample + sample_prompt_input = { + key: val[sample_idx : sample_idx + 1] for key, val in prompt_inputs.items() + } + prompt_outputs = self.model(**sample_prompt_input) + prompt_logits = prompt_outputs.logits[0] # Shape: [seq_len, vocab_size] + + # Calculate log probs for each prompt token (except the first BOS token) + for token_pos in range(1, prompt_len): # Start from 1 to skip BOS + # The logit at position i-1 predicts token at position i + logit_for_current_token = prompt_logits[token_pos - 1] + current_token_id = prompt_token_ids[sample_idx][token_pos].item() + + # Calculate log probabilities + log_probs = torch.nn.functional.log_softmax(logit_for_current_token, dim=-1) + chosen_log_prob = log_probs[current_token_id].item() + sample_log_probs.append(chosen_log_prob) + + # Calculate top log probabilities if requested + if n_top_logprobs > 0: + top_log_probs_dict = {} + top_k_values, top_k_indices = torch.topk( + log_probs, min(n_top_logprobs, len(log_probs)) + ) + for k_idx in range(len(top_k_indices)): + token_id = top_k_indices[k_idx].item() + token_str = self.tokenizer.decode([token_id]) + top_log_probs_dict[token_str] = top_k_values[k_idx].item() + sample_top_logprobs.append(top_log_probs_dict) + + # Process the scores for generated tokens + for token_idx, score_tensor in enumerate(output_infer["scores"]): + # Get the chosen token ID from the sequence + # Scores start after the prompt, so we need to offset + input_len = ( + output_infer.get("input_lengths", [0])[sample_idx] + if "input_lengths" in output_infer + else 0 + ) + seq_idx = input_len + token_idx + + if seq_idx < len(sequences): + chosen_token_id = ( + sequences[seq_idx].item() + if hasattr(sequences[seq_idx], "item") + else sequences[seq_idx] + ) + + # Calculate log probabilities + log_probs = torch.nn.functional.log_softmax(score_tensor[sample_idx], dim=-1) + chosen_log_prob = log_probs[chosen_token_id].item() + sample_log_probs.append(chosen_log_prob) + + # Calculate top log probabilities if requested + if n_top_logprobs > 0: + top_log_probs_dict = {} + top_k_values, top_k_indices = torch.topk( + log_probs, min(n_top_logprobs, len(log_probs)) + ) + for k_idx in range(len(top_k_indices)): + token_id = top_k_indices[k_idx].item() + token_str = self.tokenizer.decode([token_id]) + top_log_probs_dict[token_str] = top_k_values[k_idx].item() + sample_top_logprobs.append(top_log_probs_dict) + + log_probs_list.append(sample_log_probs) + if n_top_logprobs > 0: + top_logprobs_list.append(sample_top_logprobs) + + # Return log probs and top logprobs + return_log_probs = log_probs_list if compute_logprob else None + return_top_logprobs = top_logprobs_list if n_top_logprobs > 0 else None + + return return_log_probs, return_top_logprobs + + def ray_infer_fn(self, inputs: dict[Any, Any]): + """Perform inference using Ray with dictionary inputs and outputs. + + Args: + inputs (Dict[Any, Any]): Dictionary containing input parameters: + - prompts: List of input prompts + - temperature: Sampling temperature (optional) + - top_k: Number of highest probability tokens to consider (optional) + - top_p: Cumulative probability threshold for token sampling (optional) + - max_tokens: Maximum number of tokens to generate (optional) + - compute_logprob: Whether to compute log probabilities (optional) + - n_top_logprobs: Number of top log probabilities to return (optional) + - echo: Whether to echo the prompt in output (optional) + + Returns: + Dict[str, Any]: Dictionary containing: + - sentences: List of generated texts + - log_probs: Optional list of log probabilities if compute_logprob is True + - top_logprobs: Optional list of top log probabilities if n_top_logprobs > 0 + """ + try: + prompts = inputs.pop("prompts") + temperature = inputs.pop("temperature", 1.0) + top_k = int(inputs.pop("top_k", 1)) + top_p = inputs.pop("top_p", 0.0) + num_tokens_to_generate = inputs.pop("max_tokens", 256) + output_logits = inputs.pop("output_logits", False) + output_scores = inputs.pop("output_scores", False) + compute_logprob = inputs.pop("compute_logprob", False) + n_top_logprobs = inputs.pop("n_top_logprobs", 0) + echo = inputs.pop("echo", False) + + output_infer = self._infer_fn_ray( + prompts=prompts, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_tokens_to_generate=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + compute_logprob=compute_logprob, + n_top_logprobs=n_top_logprobs, + echo=echo, + ) + # Code to get logprobs (required in OAI API format for eval) from the scores in output_infer. + if ( + (compute_logprob or n_top_logprobs > 0) + and "scores" in output_infer + and output_infer["scores"] + ): + log_probs_list, top_logprobs_list = self._compute_logprobs( + prompts=prompts, + output_infer=output_infer, + compute_logprob=compute_logprob, + n_top_logprobs=n_top_logprobs, + echo=echo, + ) + + # Add to output + if log_probs_list is not None: + output_infer["log_probs"] = log_probs_list + if top_logprobs_list is not None: + # Convert to JSON strings for compatibility + output_infer["top_logprobs"] = [ + json.dumps(top_logprobs) for top_logprobs in top_logprobs_list + ] + + # Remove raw outputs that are not needed in the final response + output_infer.pop("scores", None) + output_infer.pop("sequences", None) + output_infer.pop("input_lengths", None) + return output_infer + except Exception as error: + err_msg = "An error occurred: {}".format(str(error)) + return {"sentences": [err_msg]} + + def _infer_fn_ray( + self, + prompts, + temperature=1.0, + top_k=1, + top_p=0.0, + num_tokens_to_generate=256, + output_logits=False, + output_scores=False, + compute_logprob=False, + n_top_logprobs=0, + echo=False, + cast_output_func=None, + ): + """Common internal function for inference operations. + + Args: + prompts: List of input prompts + temperature: Sampling temperature + top_k: Number of highest probability tokens to consider + top_p: Cumulative probability threshold for token sampling + num_tokens_to_generate: Maximum number of tokens to generate + output_logits: Whether to output logits + output_scores: Whether to output scores + compute_logprob: Whether to compute log probabilities + n_top_logprobs: Number of top log probabilities to return + echo: Whether to echo the prompt in output + cast_output_func: Optional function to cast output values + + Returns: + Dict containing inference results with raw outputs + """ + # Enable return_dict if we need scores for logprobs or if output_logits/scores are requested + return_dict_in_generate = ( + output_logits or output_scores or compute_logprob or n_top_logprobs > 0 + ) + # Enable output_scores if we need to compute logprobs. scores and logits from generate are both identical in + # case of greedy decoding. Hence setting output_scores to True when compute_logprob or n_top_logprobs > 0. + if compute_logprob or n_top_logprobs > 0: + output_scores = True + + if torch.distributed.is_initialized(): + if torch.distributed.get_world_size() > 1: + torch.distributed.broadcast( + torch.tensor([0], dtype=torch.long, device="cuda"), src=0 + ) + broadcast_list(prompts, src=0) + broadcast_list( + data=[ + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ], + src=0, + ) + + output = self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + echo=echo, + ) + + if isinstance(output, dict): + return output + + else: + return {"sentences": output} diff --git a/examples/puzzletron/evaluation/lm_eval_anymodel.py b/examples/puzzletron/evaluation/lm_eval_anymodel.py new file mode 100644 index 0000000000..7f9e07dd2b --- /dev/null +++ b/examples/puzzletron/evaluation/lm_eval_anymodel.py @@ -0,0 +1,115 @@ +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c + +# MIT License +# +# Copyright (c) 2020 EleutherAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +"""Run lm-eval directly on AnyModel (Puzzletron) checkpoints without a deployment server. + +Patches lm-eval's HFLM to wrap model loading with deci_x_patcher so AnyModel +Puzzletron checkpoints load correctly. Model descriptor is auto-detected from the +checkpoint's config.json model_type. +""" + +from lm_eval import utils +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.model import T +from lm_eval.models.huggingface import HFLM + +# Trigger factory registration for all model descriptors +import modelopt.torch.puzzletron.anymodel.models # noqa: F401 +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import ( + resolve_descriptor_from_pretrained, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + +def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: + """Override HFLM.create_from_arg_obj to wrap model loading with deci_x_patcher.""" + + additional_config = {} if additional_config is None else additional_config + additional_config = {k: v for k, v in additional_config.items() if v is not None} + + pretrained = arg_dict.get("pretrained") + descriptor = resolve_descriptor_from_pretrained( + pretrained, trust_remote_code=arg_dict.get("trust_remote_code", False) + ) + # The patcher must be active during HFLM.__init__ because that's where + # AutoModelForCausalLM.from_pretrained() is called internally. + with deci_x_patcher(model_descriptor=descriptor): + model_obj = cls(**arg_dict, **additional_config) + + return model_obj + + +def create_from_arg_string( + cls: type[T], arg_string: str, additional_config: dict | None = None +) -> T: + """Create an LM instance from a comma-separated argument string. + + Args: + arg_string: Arguments as ``"key1=value1,key2=value2"``. + additional_config: Extra configuration merged into the parsed args. + + Returns: + An instance of this LM subclass. + """ + args = utils.simple_parse_args_string(arg_string) + additional_config = {} if additional_config is None else additional_config + args2 = {k: v for k, v in additional_config.items() if v is not None} + + pretrained = args.get("pretrained") + descriptor = resolve_descriptor_from_pretrained( + pretrained, trust_remote_code=args.get("trust_remote_code", False) + ) + + # The patcher must be active during HFLM.__init__ because that's where + # AutoModelForCausalLM.from_pretrained() is called internally. + with deci_x_patcher(model_descriptor=descriptor): + model_obj = cls(**args, **args2) + + return model_obj + + +# Monkey-patch HFLM so lm-eval uses our patched model loading +HFLM.create_from_arg_obj = classmethod(create_from_arg_obj) +HFLM.create_from_arg_string = classmethod(create_from_arg_string) + + +if __name__ == "__main__": + cli_evaluate() diff --git a/examples/puzzletron/evaluation/nemo_evaluator_instructions.md b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md new file mode 100644 index 0000000000..f8b53889c6 --- /dev/null +++ b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md @@ -0,0 +1,70 @@ +# Evaluation with NeMo Evaluator (Alternative) + +> **Recommended approach:** Use lm-eval for direct evaluation without a +> deployment server. See the main [README](../README.md#evaluation) for details. + +Evaluate AnyModel checkpoints by deploying a local OpenAI-compatible completions endpoint and running benchmarks against it. + +This flow requires Ray for serving the model and NeMo Export-Deploy (included in NeMo containers): + +```bash +pip install -r examples/puzzletron/requirements.txt +``` + +**1. Deploy the model (2 GPUs example):** + +We need to patch the `hf_deployable.py` script from Export-Deploy. Best way is to do it as a mount in docker run: + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it +if [ ! -d "${MODELOPT_DIR}" ]; then + git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +fi + +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +docker run \ + --gpus all \ + --shm-size=16GB \ + --net=host \ + --ulimit memlock=-1 \ + --rm -it \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -v ${MODELOPT_DIR}/examples/puzzletron/evaluation/hf_deployable_anymodel.py:/opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py \ + -w /opt/Model-Optimizer/examples/megatron_bridge \ + ${DOCKER_IMAGE} bash +``` + +Alternatively you can manually update the file + +```bash +# Install the AnyModel-patched deployable (first time only: backs up the original) +# /opt/Export-Deploy is the default path in NeMo containers — adjust if needed +cp /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py.bak +cp examples/puzzletron/evaluation/hf_deployable_anymodel.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py +``` + +Now start ray server and deploy the model + +```bash +# Start the server (blocks while running — use a separate terminal) +ray start --head --num-gpus 2 --port 6379 --disable-usage-stats +python /opt/Export-Deploy/scripts/deploy/nlp/deploy_ray_hf.py \ + --model_path path/to/checkpoint \ + --model_id anymodel-hf \ + --num_gpus 2 --num_gpus_per_replica 2 --num_cpus_per_replica 16 \ + --trust_remote_code --port 8083 --device_map "auto" --cuda_visible_devices "0,1" +``` + +**2. Run MMLU:** + +```bash +eval-factory run_eval \ + --eval_type mmlu \ + --model_id anymodel-hf \ + --model_type completions \ + --model_url http://0.0.0.0:8083/v1/completions/ \ + --output_dir examples/puzzletron/evals/mmlu_anymodel +``` + +For a quick debug run, add `--overrides "config.params.limit_samples=5"`. diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py index 16d4de385e..5bb04818e5 100644 --- a/examples/puzzletron/main.py +++ b/examples/puzzletron/main.py @@ -16,9 +16,10 @@ """ Main script for running the puzzletron algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146). -This script provides two modes: +This script provides three modes: 1. Default mode: Runs the full puzzletron pipeline 2. MIP-only mode: Runs only the MIP search and realize models phase +3. MIP sweep mode: Runs MIP for multiple memory compression rates (enabled via config) Usage: # Full puzzletron pipeline @@ -26,6 +27,9 @@ # Only MIP search and realize models phase torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only + + # MIP sweep mode (set mip.sweep.enabled: true in config) + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only """ import argparse @@ -34,6 +38,7 @@ import modelopt.torch.nas as mtn import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models +import modelopt.torch.puzzletron.mip.sweep as sweep import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel from modelopt.torch.puzzletron.tools.hydra_utils import ( @@ -143,10 +148,17 @@ def run_mip_only(hydra_config_path: str): overrides=[], ) - # mip_and_realize_models (distributed processing) - # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API - mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + # Check if sweep mode is enabled + if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False): + mprint( + "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)" + ) + sweep.run_mip_sweep(hydra_cfg) + else: + # mip_and_realize_models (distributed processing) + # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API + mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) dist.cleanup() mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md new file mode 100644 index 0000000000..d3420be096 --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -0,0 +1,152 @@ +# Knowledge Distillation with Megatron-Bridge + +This guide shows how to perform knowledge distillation on Puzzletron-compressed AnyModel checkpoints using Megatron-Bridge. + +## Overview + +1. Set up the environment with Megatron-Bridge +2. Prepare tokenized dataset +3. Run knowledge distillation training directly from HuggingFace checkpoints +4. Review MMLU evaluation results (before/after distillation) + +## Setup + +**Clone Model-Optimizer repo:** + +The NeMo container does not include Model-Optimizer examples, so you need to clone the Model-Optimizer repo: + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer +git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +``` + +**Start Docker container:** + +Use the [NeMo 26.02.01 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02.01): + +```bash +# Recommended to mount a workspace directory for storing datasets and distilled models +docker run --gpus all -it --rm \ + -v /path/to/your/project:/workspace \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -w /opt/Model-Optimizer \ + nvcr.io/nvidia/nemo:26.02.01 \ + /bin/bash +``` + +## Dataset Preparation + +This section describes how to prepare datasets for knowledge distillation. We provide examples using WikiText-103, which is a small dataset that can still produce decent results (see the Qwen3-8B example below showing +10.11 percentage point improvement). For production use, larger datasets like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) are recommended. + +### Download and Tokenize Dataset + +Download and tokenize the dataset in a single step. This downloads the dataset from HuggingFace, tokenizes it, and saves it in the Megatron format (`.bin` and `.idx` files): + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --hf_dataset Salesforce/wikitext \ + --hf_name wikitext-103-v1 \ + --hf_split train \ + --output_dir path/to/hf_datasets/wikitext-103-v1 \ + --tokenizer meta-llama/Llama-3.1-8B-Instruct \ + --json_keys text \ + --workers 32 +``` + +This will create: + +- `Salesforce--wikitext_wikitext-103-v1_train_text_document.bin` - Binary tokenized data +- `Salesforce--wikitext_wikitext-103-v1_train_text_document.idx` - Index file for the binary data +- `Salesforce--wikitext_wikitext-103-v1_train_text_document/cache/` - Cache directory (created after running distillation) + +## Run Knowledge Distillation + +Run distillation directly from HuggingFace checkpoints (student and teacher) with tokenized dataset: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \ + --student_hf_path /path/to/student/huggingface/checkpoint \ + --teacher_hf_path /path/to/teacher/huggingface/checkpoint \ + --data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \ + --output_dir /path/to/distilled/checkpoint \ + --hf-export-path /path/to/exported/hf/model \ + --hf-model meta-llama/Llama-3.1-8B-Instruct \ + --seq_length 4096 \ + --tp_size 8 \ + --pp_size 1 \ + --mbs 1 \ + --gbs 4 \ + --train_iters 100 \ + --lr 0.0001 \ + --min_lr 1e-05 \ + --lr_warmup_iters 10 \ + --eval_interval 10 \ + --eval_iters 10 \ + --log_interval 1 +``` + +**Notes:** + +- Add `--trust_remote_code` if student or teacher checkpoints need HuggingFace custom modeling code. +- The distilled Megatron-Bridge checkpoint will be saved to `--output_dir/checkpoints/iter_`. +- Add `--hf-export-path` to automatically export the final checkpoint to HuggingFace format after distillation. When using `--hf-export-path`, you must also provide `--hf-model` to specify the HuggingFace model ID to use as a template for export (e.g., `meta-llama/Llama-3.1-8B-Instruct`). The `--hf-model` should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation). +- For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices. + +## MMLU Evaluation Results + +This section presents MMLU evaluation results for knowledge distillation experiments compressing Qwen3-8B and Llama-3.1-8B-Instruct. + +### Successful Case: Qwen3-8B (80% of original) + +Distillation results for a memory-compressed Qwen3-8B checkpoint (80% of original size): + +| Model | MMLU | Humanities | Other | Social Sci | STEM | +|-------|------|------------|-------|------------|------| +| 80% pre-distillation | 0.5910 | 0.5046 | 0.6363 | 0.6831 | 0.5855 | +| 80% post-distillation | 0.6921 | 0.5906 | 0.7316 | 0.7975 | 0.7016 | +| Original Qwen3-8B | 0.7493 | 0.6648 | 0.7856 | 0.8385 | 0.7526 | + +**Key observations:** + +- MMLU accuracy improved from 59.10% to 69.21% (+10.11 percentage points) after distillation +- Achieved with just 100 iterations on WikiText-103, demonstrating efficient knowledge transfer +- Recovery of 64% of the gap to the teacher model (from 59.10% to 69.21%, closing 64% of the gap from 59.10% to 74.93%) +- All individual category scores (Humanities, Other, Social Sciences, STEM) improved significantly + +### Successful Case: Llama-3.1-8B-Instruct (50% of original, 56,810 MiB) + +Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (50% of original size, 56,810 MiB memory constraint): + +| Model | MMLU | Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Before distillation | 0.2316 | 0.2462 | 0.2292 | 0.2250 | 0.2274 | +| After distillation | 0.2960 | 0.3146 | 0.3085 | 0.2925 | 0.2768 | +| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +**Key observations:** + +- MMLU accuracy (average across all categories) improved from 23.16% to 29.60% (+6.44 percentage points) +- All individual category scores (Humanities, Other, Social Sciences, STEM) improved, demonstrating effective knowledge transfer from teacher to student + +### Regression Case: Llama-3.1-8B-Instruct (69% of original, 78,000 MiB) + +Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (approximately 69% of original size, 78,000 MiB memory constraint) showing regression due to overfitting on the small WikiText-103 dataset (evaluated with limit 100): + +| Model | MMLU | Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Before distillation | 0.6626 | 0.7069 | 0.6892 | 0.7525 | 0.5574 | +| After distillation | 0.6496 | 0.6862 | 0.6677 | 0.7433 | 0.5532 | +| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +**Key observations:** + +- MMLU accuracy (average across all categories) decreased from 66.26% to 64.96% (-1.30 percentage points) after distillation +- The model overfitted to the small WikiText-103 dataset, causing performance regression +- This demonstrates the critical importance of using larger, more diverse datasets for knowledge distillation + +### Recommendations + +- **For production distillation:** Use larger production datasets like [nvidia/Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) for better results and to avoid overfitting (see regression case above) +- **Training duration:** Train for more iterations to ensure proper convergence +- **See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices** diff --git a/examples/puzzletron/mbridge_distillation/distill_hf.py b/examples/puzzletron/mbridge_distillation/distill_hf.py new file mode 100644 index 0000000000..d21f35ec16 --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/distill_hf.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Distillation script for Megatron-Bridge. + +Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model +to `/checkpoints` in megatron distributed checkpoint format. + +See `README.md` in this directory for example usage and data preparation instructions. +""" + +import argparse +import os +import traceback + +import megatron.bridge.models.distillation_provider +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + GPTDatasetConfig, + LoggerConfig, + MockGPTDatasetConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.distributed import DistributedDataParallelConfig + +# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure +# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers +# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge +# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration. +# +# Note: Currently, bridges are also registered when distillation_provider is imported +# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider +# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron. +import modelopt.torch.puzzletron.export.mbridge # noqa: F401 +import modelopt.torch.utils.distributed as dist + +# Use local copy of distillation_provider with fix for heterogeneous models +# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge +from modelopt.torch.puzzletron.export.mbridge.distillation_provider import ( + DistillationProvider, + convert_to_distillation_provider, +) +from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import ( + export_to_hf_and_copy_config, +) +from modelopt.torch.utils import print_rank_0 + +# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider +# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time +megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider + +# Import distill() AFTER patching so it uses the patched DistillationProvider +from megatron.bridge.training.distill import distill # noqa: E402 + +SEED = 1234 + + +def get_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") + # Model arguments (accepts HuggingFace input only at the moment) + parser.add_argument( + "--student_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)", + ) + parser.add_argument( + "--teacher_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)", + ) + parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code") + # Parallelism arguments + parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + # Dataset arguments + parser.add_argument( + "--data_paths", + nargs="+", + help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)", + ) + parser.add_argument( + "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data" + ) + parser.add_argument( + "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices" + ) + parser.add_argument( + "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths" + ) + # Training & Eval arguments + parser.add_argument( + "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving" + ) + parser.add_argument( + "--seq_length", + type=int, + default=4096, + help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.", + ) + parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size") + parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size") + parser.add_argument( + "--train_iters", type=int, required=True, help="Number of training iterations" + ) + parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate") + parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate") + parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps") + parser.add_argument( + "--eval_interval", type=int, default=100, help="Validate + checkpoint every steps" + ) + parser.add_argument( + "--eval_iters", type=int, default=32, help="Number of batches per validation stage" + ) + # Logging arguments + parser.add_argument("--log_interval", type=int, default=10, help="Write to log every steps") + parser.add_argument( + "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)" + ) + parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)") + parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)") + # Export arguments + parser.add_argument( + "--hf-export-path", + type=str, + default=None, + help=( + "Path where to save the HuggingFace export. " + "If provided, exports checkpoint to HF format after distillation." + ), + ) + parser.add_argument( + "--hf-model", + type=str, + required=True, + help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). " + "Should match the base architecture of the student model.", + ) + args = parser.parse_args() + + # Sanity checks + if not args.use_mock_data and not args.data_paths: + raise ValueError("Must provide either --data_paths or set --use_mock_data.") + + print_rank_0("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print_rank_0(f"{k:<35} {v}") + print_rank_0("===================================================\n") + + return args + + +def main(args: argparse.Namespace): + checkpoint_dir = os.path.join(args.output_dir, "checkpoints") + tensorboard_dir = os.path.join(args.output_dir, "tb_logs") + + # Build student and teacher model providers + def _build_model_provider(hf_path): + bridge = AutoBridge.from_hf_pretrained(hf_path, trust_remote_code=args.trust_remote_code) + provider = bridge.to_megatron_provider(load_weights=True) + + # Override parallelism / training settings + provider.tensor_model_parallel_size = args.tp_size + provider.pipeline_model_parallel_size = args.pp_size + provider.context_parallel_size = 1 + provider.sequence_parallel = args.tp_size > 1 + provider.seq_length = args.seq_length + provider.pipeline_dtype = torch.bfloat16 + return provider + + # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. /path/to/ckpt/iter_0000000) + # Still requires an HF model name or path to build provider correctly + student_provider = _build_model_provider(args.student_hf_path) + teacher_provider = _build_model_provider(args.teacher_hf_path) + + # Wrap into DistillationProvider + kd_config = ModelOptDistillConfig() + distill_provider = convert_to_distillation_provider( + student_provider, teacher_provider, kd_config + ) + + # Build optimizer and scheduler + optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=args.lr_warmup_iters, + max_lr=args.lr, + min_lr=args.min_lr, + adam_beta2=0.98, + ) + + # Build dataset config + dataset_kwargs = { + "seq_length": args.seq_length, + "path_to_cache": args.data_path_to_cache, + "random_seed": SEED, + "reset_attention_mask": False, + "reset_position_ids": False, + "eod_mask_loss": False, + "num_dataset_builder_threads": 1, + "data_sharding": True, + "dataloader_type": "single", + "skip_getting_attention_mask_from_dataset": True, + } + if args.use_mock_data: + dataset_config = MockGPTDatasetConfig(**dataset_kwargs) + else: + # Convert flat CLI list (e.g. ["1.0", "/path/data"]) to Megatron blend format + blend = get_blend_from_list(args.data_paths) + dataset_config = GPTDatasetConfig(blend=blend, split=args.split, **dataset_kwargs) + + # Assemble ConfigContainer and run distillation + config = ConfigContainer( + model=distill_provider, + train=TrainingConfig( + train_iters=args.train_iters, + eval_interval=args.eval_interval, + eval_iters=args.eval_iters, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + manual_gc=True, + manual_gc_interval=100, + ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters), + optimizer=optimizer_config, + scheduler=scheduler_config, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dataset=dataset_config, + logger=LoggerConfig( + log_interval=args.log_interval, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + # Weights & Biases logging + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, # optional + wandb_exp_name=args.wandb_exp_name, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size + ), + checkpoint=CheckpointConfig( + save_interval=args.eval_interval, + save=checkpoint_dir, + load=checkpoint_dir, # Resume from this directory (if exists) + most_recent_k=3, # Keeps 3 most recent checkpoints (not metric-based) + ckpt_format="torch_dist", + async_save=True, + fully_parallel_save=True, + ), + rng=RNGConfig(seed=SEED), + mixed_precision="bf16_mixed", + ) + + print_rank_0("\nStarting distillation...") + distill(config) + print_rank_0(f"\nDistillation done! Saved checkpoint to {checkpoint_dir}\n") + + # Export to HuggingFace format if hf_export_path is provided + if args.hf_export_path: + # Wait for all ranks to finish distillation before export + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # Save rank before destroying process group (dist.rank() won't work after destruction) + is_rank_0 = dist.rank() == 0 + + # Destroy process group on all ranks - export_ckpt will create its own temporary one + # This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone) + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + # Only rank 0 exports + if is_rank_0: + try: + export_to_hf_and_copy_config( + student_hf_path=args.student_hf_path, + checkpoint_dir=checkpoint_dir, + train_iters=args.train_iters, + hf_export_path=args.hf_export_path, + hf_model=args.hf_model, + ) + except Exception as e: + print(f"⚠️ Export failed: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + except Exception as e: + print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}") + print_rank_0(f"Traceback:\n{traceback.format_exc()}") + raise + finally: + dist.cleanup() diff --git a/examples/puzzletron/mip_sweep_example.png b/examples/puzzletron/mip_sweep_example.png new file mode 100644 index 0000000000000000000000000000000000000000..4eb1089fe0e980cbfa9d9839f524352a6c649db7 GIT binary patch literal 53715 zcmbSz2T)X5v~3$OV88$(C@?B0AVIQZ6amRO(;zv6WXY(bj3A3~9!UqSvmW*XNCx9Sj)eDX-;MfCiT zi-(HpcKz~FGP#n^$t!1=FJ*Q<(lUER?5C8hCGsa3IhvnQi79*%`wA{#T>22+263px>@oUDrpk?u)$bGWPc(?oCBovYoHF1XF+? zg;&uxV~#$e7=`>4jy}@ck-|eGKMJGtkB=RGivO^@9OMu zmHWE@F86gS%(6*)taZtL{Rb+`)iRs@!i=Fh)Xi-7jYV#3`+fDCtmNcmfp*Wa03Lf! z+3kriJ*lZ<@!#q0-kpWfx~DsmoK{B3_uo=dQj3g^#VWBT#t9Jv)Wzsq9u zHD)c9CW@ib>Vtjv+Pb<)w_SK9rdlgheK?z;cuqa?`xri_Owxi!=6>(qb!3+eXOioA zD3Y`4GW~Xy_;8cqTNVoSz`@S(I7%TlB{_Mzpk1iTpS9Y|X+Bt9NYTO~^F)3*+l0E& z_WFP=VX%>Qwt&TAUYqZFF>WNJK=$-WDuEp+PP2kmtrw zN({9^rb$bzO2Kz!$Cs8^d3bn;h=^u-v~6+jOI1@Jv03dg z{OU6^GpW7G@`A>=MNDorPq(e?D`4@uC5bSHD2X`Gd-3Fu`TmRiZwSh9}0*@~Ve+lo+|u z>y5GcKBR*sLStd6S~%Bx}3IFCXNe+y>Vq>L6_6Ava%`-aYB16K%(*uR~HBOv5uP{{Zq?(~b z=j!H`DHp?6zA@}hSST5Z^sD9#pyXw ze|@ypA6MaGOLX#NLsyrwY$Rt#vSjdz&a0@2&E>I30oVM$Yih_yNoieNT(&kh={Y%7 zMBkq8P}lQJDJ(2}Jms~uVhmrB@?R?wqE5hC5New zb6+g;py?vW#R*xjx_g14yG?vnw~yJI85>J7jdc|l-rqFL@eJ#BP*>91(DT}&!$jAZwFJB@t4YmwdY^^PLu)E4KqAb*C%nF z>k+Lw#1(8?^1^$`9Z3=x^RCoPow9r#G@9(vrTCs~?YZT88ho+!z>0uy$&C~v!4X)M z&lMFFnJO8H1GZQ+1nQOoVU%&mF;qiWcZTu>$<#myi^FTKOT$&NQ;u^z%sb(=25h=7 zei31WQsrVS94cX6HD-gQd8#|nG}ruVU(Dy#zaJkrirS;4qoaVOX(jA}e^PkwZ>6bL zYM0n>?+jG9hDEl0<^AW*ZwM6)Q4%TnzTmbA=bsJtjc-qh4|?C&?{ z>RNe5to&e-H^Nx;+6rC=r3OyZYl>S zC$6?wK}(1Cm+y#ccm}Yd%R?;Xc)_9?8N-mrG#P*>rcpEm`?$W$irwve?&QS$dsSTp zkF7TH9QO!t-6P-&2K1`jaKD$A+0Z*1Bi`OGtMTV+9)z->El!>~HM#RbtI$mUrjSRk z2@OoG&;tAX*WWrjgx|jpFG)^IyFntf`7Fnf+q%EI^u=;wD_d7*=X-M97vG*kH0IfW zr=z@l`7%-49z6F0h!WPizPVNFTCpSd@KF!#5>*GP+&3ndaV6_uTLYEc!9hX9xORfn zO4#iQ&gTcX271o$Adq?3+SYd-Ppr zH+ov#9)rPPmM5cZ`)}^7Of-x8UZf)4bqiXZl7!6Ek*kM29w%%+-f&#dOZx-0pw>^F z(Z`6#!*KgsjL;nY2b454z1|_r+;5 zT^sxQ^2A&H>vfAm6|&augjwSPvuBtyWE2$UFWlY9H^amw>(DMk#;`=%1mihblQt4< z+e9m}1m~q;>&>k#vWNHJXAD;5%WN90pMv}PeBNw&^7*EU+0D}Z}i#FF4mul-&up| zU`qJFoJ(9uDF|ms$>(U6qo;>FuZ1eD`+A?Lb+Pr2_?H=I$gcT;F%Rwjp1jdjDdg9r zc6k=ej1!#%akt3}z4`MsiHvocY?MjC!)J5u3kXu^Hq2DZj~&Zmb;;77<~7o<^urp$u_@8{Ag(3X5fE=>n5hcECoarDbIiqhoae1zSbu z&z<8as+b7XCf*u)b{#CCxXEyLbNS8UJ!wf{7rHH{64gIfHa6NzSjoxB4d;5Y)%GTs z{(kY|#kkg`M#jhJDz_pL?=NNhV_Q>kC6qgW1DI84c)epQ2|e0x_^|0>>5!>zJbLu# z4JF2UpjaG$lByo>0Ax7QY0EIe4!O_DOF_V8f7J!jZSV#ZIA0nSVeNwv-djRQf8y0E}V1h|~Pp)2W zE_T1);yuJ|35dbU_V)QmTm&#fH&V&Y&dv!T6f;30=)SIk_1YVso5O&&igK8?sla9> zTQ)~8Vh_Bi1l?2s(6za|AN^8vYYVa}og`$Z+*Fwe1Kj>63F^E)yeNsFky9GIe{CZF zQl&C1pBiLir^P|uY(4DoB)iX4th*fcV87xpkj>(A<1$q-snTIA47YFpeQW275%XuO zj>IC9B*B-C)i0 zmYU_om8$?NgMzyboSA#`js0=?F|efJ_L7p4jr)TJdDAaXQ!Av&#f&d3WMkBGXMg{y z=`Xg9@;caa0mw+s*uD!+W;x#QX}jHf|BfV|(`-m6vvw-LnmND+w$A*3kOHXq)vn%^ zX((w2xEVp6L|ax~ZqfTif3`O-iRqfZTDDdS;4+q-&8^ib!q*cdX0TUV>v2IriIc5S zbGF_)QJq~~gf{-AX+SG|^$=c@A?zEo(p?}S!9#W;KopML-hx20w&Z=F{h3j<=|gL* zpf*H&t|A9#=bm)Mr0x`Ha=_RY>+^m6d4|3WaS5Wnvkjl_09a|^*eET12rhy=27(8# zT)9F@LBRkh3E2%%--}$#2QeI`ABH_Q(5|ko#`XJ~VpW>AIJnOQMWeuRH^3e1p=6JCWJf4{eKE=jb9ZL6Emjqp35$A z9I!6QFl~?R%d7P#kJ>(S?pzl7K`;s6(Nz3wL&HUA_n4Vm zC)c*3;rhLGZ49@(ARz>|y0+RbtX<>|pi9PjUhpd8ojWyW&Yoo>OwG?*MGJciRmj40 zvg!(pjHV%jmRKOa_t+S6;jR~Qo)6En#OKVnCnYAHhZu$t+Fy-3=oL7C%B?O{HcBQh zDE2H}*g@=dFt;D_h!gUqw+^`NjLbjmXZDY|o zOSh`O+vfXUm*495#wG-zDa8qRiSWN>J3&>sED3W z$4goY$a6eMidZe<_uMYTU8(y(#jnL2@ZQgl9deF^1wKiF1jgpbv1;vZ%C{X6y2H-i zN!rZYZ1e3EQCX}vCjW8FQS=GuA{ziND= z`ky>`qEXbiHF%7Epf~&`%CDiXua7fnvgz~vatNEGumV&!z&K}eTwU{uc<28-do^pn ziT}m#XG4x%vZ=jg4pW6-wDaxwY;98<9&o=*Tt#W=^c!;B2A&c@$t*ck;lg8G?y*$0p%r6l`PyE)OlbFa4JBd;d64NaVBWkt zt7P}pG^{^(f!ECd7%p-3diE@y&vE+Ms)nSUlM@c$L-^E2a#&bcVq)SN)Ppksqtv0O znHU?pV`@4z-)fK>$3=@CxW2Vb;j#XL<<$A>@sV7Xkp^*I+ve*B0JpNLd1wHqcg?4~ zqJ)hGY;EA@3G7DSK|8Pxm9DEY^`|dT(t3Lz%(o-3i*Fy+b_TFXXPozbgrMjH|FksI z*r;rBRVYv$Aw-yXSGy4`kjf#I8`3Z9b8p_bk*{AxfIZPd??Bz4Hg#>h{{WZ}oAQZ@ z>9UYkh0lEz3G&!{e*sg*D@97`6Fq)P~=>lAfEE^P15EMExb1S)L+g2Lx1ZoVu z{Oz{TU)051GHBq3KG?gD9uV1ggNp8BT%2CD;gr$!x2p5o6tCYXQbKI>9CGNYbalYs zm!ZBDgk@HKt%ttjlMG_!-O)v(Yy>K5m-G2ibO7V z*Fi`#1%XJ8fP#%^vAM@_6S$%QtAQLe{~lnT)ZTBePa}025*?xVNHv7iQMCO=C4aBk zTW)7#sl;zWO{Rgt^uu~6Ua4|U*c%CAdAZfwj5{ABHDN5UUBQh(<)1o!?q23*_?#G| zEHsUZ;4qydQ7PCrKrAjYdF;+2Kf3aTBHRrJm9!b-rc_h)j-Ecu zbgp`5vFkLA$44sdLbFf6uiS#ogSxzF*2aN1ECW1T9%`9Bysnj17PaS^q9mU9I3O<- zV9s!Ph7uQUK2&Ip;$cZ)_tPm_o67P2vKN0}gMU8u;VIb(^$7iZPtDgUQ1y#6i z{OURTg()5O(1J_!l92TB-9#-hLxhpsVqNKhLUY4{O^Z3Mo}QKBXkz=h?RpyT!NPG@ zxL1FHX(0GPDxgA5<~`;nC4m_&yWiuD!KoTHE>g|%<^nEs1Si(%7Z8(oo9EV%i7hR9#+W9fSSGdj;X*%?d;UmbfS|v8w z+O}1ydpjFEWb65D?QYB;ll8p!xZK|r06Yd%mJ7%w6PQOzPP31#Q#`_NwyKbPs8#Jz z#t=HPA>A6urHb|PRBkcIYZw_BSu0|lZ@-WFP%A@^qJIm?Qw zL4cl^j~|PcRN4gXlLu|;8x0(H*$i)ce%+s;4brg~ZI4)dx+(vO1##RPSNhU8z zB5)Q`Ek~`DVB0GJ;dMS?o?;t!@4eV(Z}c&jT&o~?Bcy@Y_M83f2>Eq4mi_e-Y~#?7 z;Ld7DNXWdV6|dP~#8C?zN0I%66zq0O zrftmw$L%%qz1`Jzgw5zq6gxKrR2JPz-Aqo!q7t~aHTC@#ru0#rjUdviSA%V%z;nB% zG~H*%apL)c%~RVEx7OX5Tcy45L5>TiB5nHxCZZF&``PG(5S(@bl{Z)a^#T;*koVy0- zSc@zz*)lf;wg?Bz>{Lck;Cbn7!7TKmwpw1SCtF)nF`vGhU4(siWBs=brlPp?=PlNBtv#7NdKji1PM1aFG zWJ*gw6GV$?oH_2=wmJTtckj7Zg-+68Eg3MfZBc)C5OC^Ja&pCtj{|P<=0BB|OtF+k zJ46b5dwt@@vIb1$Bv&t%H8Oi^+v(|z2<$X1&Dpfi)ipLs9&F>Wj*~4Ep8H!E+F z2#a>*iUFbE_T9VUfQyT^BL()>y8Qh7I$=LJy>?$9^b=tgSFo+?9~g)S+>>pi^eHZA z->Y$9i)Gec8D)Ipa903MAYmj`0}|}Hsu6j}SXl;Tdt3uOR37>Jl#5eSri8E&e{$CT zO7}K`w3B`CYVQ0sf3YtP^E>^ZCU!VoF&fJ% z2_&Ri-rUmhkcsVzx<2JLo6&YCcUmu&@_C!B{qBD#%n6Nw>cIVDd=8o&O}ZgCJ?r#$#m_FukRcv5k0^<7P)WO z)YjITx5o-1Y=4Gz+v7j0-cxYYalV!*S4(bQ$}z^TuCJ@XO=a@FeED)*Z-aN~qFSU? zRjAEK_0Z(T0zo>xpf^%eMFk3R4l)&>V3Xfny*CTlwl6QM;wu>rnP>GdIg1Wr3g;a6 z2_qxMGi3$j+OAT?-uql;KjJkE``mU%JXU32=0*X`)Q+LKemxyZy{B;=EAN1MjfZ^B zz{1it+h%|E>{;#&vlRjX2ZbD$=e8rbshUI}wW@(ZA_B!3#f?o&EIJa!-ndmF;jhvI zi=uZ0#pLTnYWMj(0b$m^qNALpIstiPgC#3N*(%mwf6|XC!SwkmeWy~v^HXc2w+HI>#%vt{e+}YgRd>@IJq?mxJ zrOH*gD7`c==o#?Pd{?vv(={%j(JXR9g8F569z6J=5>lYSv<*N>ciPKdcL#{9CKuZ% zV9Cu!zrPb34sM_ZJ~hklEK<48_Z7(Ib$6>ol{mFqw#WTVeDPvW{G-=^P)4xm3TIU_ zGqX?i>0tcp`-b=(Ja#`NOwWrKuW1hYhTBw(i%eh%lc39bK3EpBZe=g2|A!9{3=6XQ z*TE(<3?cQ^TO4HJ^}L5BX|CoQwZ7j03zc04UY(wiv1yXGjBGg9TXqX-V*8=7rlBD; zgx0_r0?X7Ek^bC9DCEZ<7Tn!{XR<@jl>-B0lZ^UV_pNRm_3Si>z;7S(nvv)m=GVS| zyd}zM+(g=?*FbXfVZ8G~{|t<%CL-c<`_3IgozNQTvI!P~>7XuOVy$@biv% zv7GHs`pjT5T!X;=EiW#0L{FwX`!p>mg{oKJTonp#_Bp(g2u zO^8|AzX51IX9G6Da&@w`FK+54wp4mKe{*}=;`h&Q+oS%h`G&k3oXx!2jLq=f?lgHC zL|lSI+$GC1lnhqfI540MjtRIn7Cv5y0G?q5;0sjQHa@HCc7RQ0JbHa9LorDl z@-hzgZf$1?i||frpaCLWG+!EGrtAja&mfcrET@u+N(dt36&)x;LF(sO-}&+5CCE6P zh1gEm90+8h7iffZfENE07Z(TBzatRs3xj3K*!^uxP*4!3{rJ|4u~B@5IT137&6hm0#h#AVF3)? z!W&Q7*}<{`*h|?02Ugw6bkHwm0X3^Wefk%Iw;eJzO6|r}fU0>Ma610g0~E?N62#V9 zjcaU{Jz0+l@PDQwtpc|QYywYZWo1_XB%!0kR_~6W;812RM3E`?oZQ@8)hv)F)I1k@ z^J;#tuHIMCEsnjX-7cyhCtdU4JyefM;W&MOZZQo`Kv`1xnN)ToeKNCEox!VO}_FjnAT0SlNz zUgEZm6LgrA8S>g$D2vrAvQW(iO-sbJ2sly@Lbz?sW7H3zP+@$&^9rE^arWOqkWGiG z!8LYjV#2tn5{MJD!(qA=Y@NJN6&Qvwyy-?2I^EvRr;=0AnZ!LAc2PLPPT$IZHQ0aill_ zc|mn?aS>EM1t{czf5yZLx@Q7xajvRSyrz94hy=`3#5NR6AQeD$gaML$Ku3Ky-wvF0 z4x-s33Ou6s{`J=xgwW5+%R^*pfU3A;$*Y9Lp@B)S>%fRa@;ao(dhfd;^f7jCJ?bL0 zKss1z0}iJOOQQ^Tx2YoZ%%}L^EjyFnK}1RA8Ga!O-lYJ+&%?u`J5%+30BBh?z9c-% zPXVIDHDTRvKH9(T9*NY zLd0;-orMxDp#6(&M{+2)2vd9f_Jy;lopF!%%jPO+h(itRr4K)`Hv9NyrL- zauT<@fhWE&z<@9P29*DY>n}kzS^Fd%wfUQNVs&YwS|HDFa>bNXj@k){Kriv#jqu~lCpEZy8-nS+#+ zR8Uk@6e4OP^96F9QqavYFZ47%7bLr4s}8`RNOD1%VF0#DJ$?EyFfb4iX#lmGfp|c~ zqIN93lh5QxLRm$O<PQ<`%CIYAVq3a9#50uFY#7a$2E)cc`+u)4OE{1-8* z^wwcRm~kgi*g044$@L1X&th^zw>-_B8@fZySmd?m7F8eXwLK$XD}q9ah}9n6G6y6@ z9gr>T94SRUIWF|8DyPc8XB6t;Nu+N~VG`Of5X2D)l1cT}A3w$y7II3YDKb$|P&73(Ja?q@sd;|kr+7UV>Nyok=Y+@>_F`veXUNBo3ds)l zTV&EZGaw^F%n2Gch;Um#szuTPWo zK>y1G&ycPLenQjw0;MHUMDjSYeZv>i{~MbV<$~-X#8S`VIHTDXFOz&}kQx4BKx+ z9#1<1b(R|7H^jhz_hTNZh9Kqi>>)wAH*$SCP z8S7(EOcWsqWa?q70Y4i%?GC#yeWnFhU>d2;CsswOP(C#f!@*m1phN%$&viKn>HkEb zO238dOuT^hyAWCkRsEC8S(}igRu=ACx@>@d4+!W2Qs@fjMG9kn=XnLtj6niIQ2YAp zqwYxsBKpUOwFQCLX0ZA~Z5OtGYn|-+^>EVxKx-4OyU;bj?kVKDVgx-`PobvZX1>}Y zZBx-g!)yNp@K)=v7*+uAxRWEXkEq>$F-VZ5$LHsDk{z1&>?^^>B!fUxXdMI74oU{z z*JEe4zX(HXzBpJ08^A^7|Neb9L_q6x8~}!@)i!>FPaPc_1E|G-hzCfM6!aM;pQe5h z|D5E=S>j(YY;YvZ&maUbWN=3?b)*juDLz%w`KuWUt#*Oc_sN4B@0%DNKGYx) zT>E&~>PEzO`fxkrZ!^NPKn!M<$gvy zqEn}4As_P;1ws$%%M+)q7Q2uJ3kzuN0;cHA@e?NyV?(S42s&si&2w9u*0+d*@LNyq zF2h>+`&ECDWg4)?=+ftp5F7?F+{DsSF0_C)H8*3P#2v*WpZu)B_CJGYNk~W_n8^Xl z848w<7#XA{GeE}3o*;2=5Dy)oxv&Zru-ICq&l8Yj4|2W*aEVBvuet-&;w;1sbr6Oe zfB$?0WyVt|*Hu9NiFKW%K=c|S-RdLzQBc%`E~}d{Cs)FpB&wkk+TrK>Npb382k}FW_Rt#n%Zf=}Qo{A)`B zCRJx$Z^e`a#s8ILx)6fyL%1a5=jYoH^XpPFvg(heB1a4MWhIp#TWlA3D|eE`MW&vN z%5$VkPBpG)Kflwaqg6Rq`uB~K14Z%@|8}M~eJ*@2LuW$vMyZQ5&VYrs&(D}FtYx8- zYEQx5X{rOrA zgWWp@^AB?07XNt&DVhJ(Zrvt{c$<*iy(VBvI>B_;m{qy?vUILblyw*>tBbls+B27b z`s|%u7B1@ha^73{Lb%)d-|yqOT6%Y>_7;^<0>?Z17WO8soY1`kozNn>U5s%qA*{K1A~!9aZ0Kwsl{8x^}0S@V!1|hGVjJnTFqoE zu?WkGvx92Nytl*mL9{9Z=6Q&!LOILi#N=)Aduaw!9AAtDD@YSgapr2z3pG*xv(gof z4V-JE8`@wozlQ(J<$W%Gl~6IfIFg{l(#~;SRMMQH%qAqp;jH#3p+4V)YL9RLnJL8-|%j9~FxOem$QwXhP*$R@Q(k z+m!BMEWd1BvsORjPoxzTshlg$#g&MDeWjsk`5$AOHyGexWYB`z6>Oh)SqpuWrp!vdqRxuF1A{UQ%Lj8Fp2q85RsB-ck-ui4qMW;7 zsi%)s``7B0huRFCXH19T2=Tlfz^998-A~O!o`Sg{KzCc{K4{}ow>7`3Qe`ZP1ATqp87COcv(ow=vBJJ4MwxbOcp$H3)v@s+}Jz)OK zPZO5B*3t=I(`XcZRIN+APEh)ooAzq2l5Bo1CB66Lo+Y8pU_&_Gb!hb7+4z=(qX)kT z!LU61jL;ur%DxoOmt zi4o}v1WN$Z(Aqg3ydpUs3w+w9_3X~7yG`tLVn)8uD(**6l0@MEg{v)L72#A z^fTgp%-0ZAv&{J!weh{r(J@R&Mu7-{Y1j%d>+)&+_LA7|=*SQ0C;jW?A>k|PpBN-N z^3LM-4zv8av{mI}u;kI|p|XVF-cx;XN~8v6R@cjj?VQj)`lneKYa^K_k0eTL2QjL6y3t{p(j zU96UlE(jKb{FVD|yGr3rx^95QEWvAmDpuGhNc8Y@($@=(56;)#O)+5&Xf7iG6P1Mh zV=3~`iDt@-Rn=)4{yiYz-`Je(QTb-|xl_-;9>1B*xW4Drhz)8*uDvR*MA;JoW_X|D zqzd0Y2dV?pi$c*SLu88am(Gv<<~S59(`ZhS5*nE6J9l8hJQ4nlbbcx&bC{aUf|V9L zaM6|~^_TSx`q61k=3ytDdjsw7bJ_zWJolEUKkw+W>o zoKdwYS%MK@?2Ys(9?{I#R1Yj^?yl;}#P+?pYwLnnYD%@T`|nH~_3-$SjT_4yIm`WR zi&u$;!7|O``NXo|GAW_*bS~sXJ8pW~x5LJ4?7xB0q>tZ%xUJND8r5{?$T+0X9#Z{9 z2AxjQV>t&0Tcxzdp`D!IYGAe;?wUG3oZQ)QzgN{h=7xC)*tBTMpq|MeMU`Fz$AoV} zj*#MBH!6=zFC8xgL+!$8e)AgclOD4eFRjpdMkim!$Kih#B}I8S+4z&-RsAp4vO?$i za>mGW=oth>g{T*aGjt^PqVTFy_N5FC1{US2`V6z>{LL|e(n@aTvW7u0nLF<9bbdR* zT0IWV^~Ep^I$c2BJUWV|aGag}LhCz$WWXrdw({Owny1_9Eda-4{+dd;Y<*+8+F9Rt zbbSpnwEx*gzYj>%l1vqhNZi+ACy$EOve-FGB9mWCt8}B1|`yQ zv{my=yv(9k!}fBxPPdH@a5Hi%E@kzl2+Qep-`MMKt|>A)#|V~jFfhK8+b3h1I}ki$ zr4!W2&MwXAJEux;=3x#v*JkI!3r20ic1(om(V?||NqR3M%Ibv`vm_BGo2`VQ%@V!v zV3ee;uCa%XX28N^%=)>aSfXpH^{3Xt1 z{Yu#%SS+!uH(4yBds)~!y0<8w{FjIB{_Gxp{Ux7HmQZ!Cy~rc4*uL7-f)2iW=GyZMUqapQQF(!{Nxa!2s0+=MFwT_yd;`5FZneEogOoWHbYL8O#J&C)9~g;sezfgd#;<|x!N z^+VTC$U#!aJ9(vU+>@99*_Qo{?N?^6T8v_f~lr4jG^i6K77H=3!;c9UvUPm*i5 z5^kt5e$!&x9iwcf@TvVYDn9El_4-^1j*ZAzwUg`-OU-J<=H6xGWlmDD&wC1?BLT_{ zNMmXr^pc(IBM^vK`!&6wt)x#ITzY!?5SJZj1CnDZTfueUrock4a^;TK5ULVU>eT?YTMs1t*nBs6fJnL@9EopDP|kgi8O>)Gluik4-A&Ar)!f14eKX zj6i7(XDK}iQ9dI?t*No5wvW@FVG8^5vw*a5hK{+tbN)0qE9G$U%iAAfaii#Zr`vbc z-kHy=EUzj2g)ISmjgd=uG=Gc;*x17jr}t&QAWD8UD3>fOpL=V6ZX_#FSw{*CU%9V& zNJpbd(P(AE@%~lmiJRa08WVq?`bKPfDJwkux}*|l^REkzqvF>NEs*}TQYVMwdS+QE z>E|LD7H&n;v7nGpMLy{XBx_ENWl(xA)Hr&+%y=7D0-My-^;z%!F_g&grZunv@gcAo z5ccNEbn>D8$51{7hsE17 zQ;LYU2BX0_WYGToo1d>fsO-<~HKpczUk3m$#X|7Tj>mFuw=q7%{U!tj`p#XbPyxC; zgc>ob`0qyla~>(b<)EZh-nPeO$-q}XG@b}fF2o`^6OmFdQQ{vg_RYIhtK5rQxd#i- z%-+S0<+BcK*~()iGWVh(u|Y05tkSY)a_`Q4-O|X0kP6n$nP`1t2+xw$G8j4h;5=GV zifc%fWjk~0AndBL6f+Hti#m~B3Au9HZDkR#RYOG8)gy7l1aQmN%dUkb#Y^B0Z$2af470VGykBAL*5_^akx#wQ0eutt@fVrh$;)FAw>1LRBJ<_ za-FL~{!^TXfUH=`GjOy6CzrBkdaSNx%NSTQkw?o>S7GV=|IH(&u!|e_g!>qhDqTw- zMwzlrY?eT_)``|&1e)dwYT)(OZ zy$Fn}B62jvHCtUg%ij4f%ECx|)IOTm^i9g-XL`8?1jViq?U=`}=R+>TUz@HR?g&dz zSEuTrsv{Jn&uZ-8#z&x>zW95$qhRr~BrtIQL5ZW1@&vyL9NYKF{}VAh`q1P;$34Se z;?}}s%Hu+AT$vg8ouaU-H-cc7R5aGh!q}lqxH$Dz;ou%i+=$0Ujt-NxIE3PJczGRr6D0vud&zK~{emN` zpu(vy2yI0l0=QhT#Hhea(7uOs!_Ea)UsIkbsbGBC82Edi{KEiG{{^`S>*(*rOKR<> zZtv3JTk|)LGyFIVfU-IaLd3L!9LiU*2M9O9R$Dcwvo><;S9zH@+0i&FKFAU_sDp<| zHW$xM?WM`u$*Y^TUhfoRew@ii&jAC4kRPr`EPzm1&1%*A>#^*$f;!P$o=?e(a%#oL z;F5;JYV;#UrshYpu=uR6uzchwRN>q0Kie#{`F}i$fCr1OJOn3JO`PYeiVZ5NM?C)~ zb#ZX++gDOPa22z_5 zI?9l9XA8IH3l?W`{#y&O&;M?xEIUe_HWR6HC#NFeDs z$mr9)7&$6;Z6=^%0HZ#3>Bs{75+9Q$0EwK6NEje@@V)pQ4{SfuE;T*mGJ*qk8uYK} z=5W%=ezvdJ#KmQ!) zzDOtb{BHU&)R#CC?;%V?)WvXvYm!}y7PcG1U%t2pUSzbnZp>v>ZHnI(RxHo7) zmzP!2(9m$MgeOFfcp(})94WXqX9oyVNoe0XS7^AjowVSo8=KaDEH9WTwjZNF{OtDq z4Y72P4WxovkN+SngK987p#JRHvsmjuDt;3u0qimeScve@9L9Pdny5w~)zV+kJMenC zls#lk7WYE?{?Gtp_t9X9$XhjP$!KU;C0ABf@?8g^M=yw21%;?@lPwWR(3LQYQ`h?( z6{WqN2S??gU3(xicek#+=@&eKpT9q%YsK<%LW{l0E}S#W)Go<^_Q(~_o5ree;37#G zb<@uhWn3F1Xu9_3`z_4o=7$T8mQ>g)JDN5BNuK^sYM>0g`{C;=KqZ5RoP~Dy9#CvE z;OHXK7=WBv(9y}k%|}PS=!ZDam0wkSPGQU@BljgwUn~w$V|zu*;?90{;;Dqj-KhG4 zb=M}_Q%Y)j!n?m;BOSWX$SuDO!fbvZ^c0IAs@t&J9CKH!;#PoO?@jP~LylN3QN}<1 z_+vzbRna8HqD#&jut}I4m_u zwAj_$EDf4nC#cji%g{F?iS*qRAWd;WaNq+uaS2C#(?O1_rKUll+_!(VUyUpmP!j*J znS_qCu$sDivGFVn&l78pEmQp83wwQb(CUPQxuC*l-(Fq0L4)#1)X3WixqI*hmxZyZi189VW+z;p^-N17TA*{e+$bzcY_(48S2n{`!J;<9~ za3jz<9}*jDd;Zf0If1q92B#2WOI8(AJRQ zyf~NvK^*C%hg0|I#Wq8kd4~0qZo|;WnhWi;84w(i7IK(vvTl`|IY`4uD=`%{9STLX z+MGZytC8Oa`vrx8U{?Vz5kfmZ{ImAK`{;ou@#mY$?~FiC$IWQT4y31p5nGaQ($C)# zmGTQ;)8S^X_-Cg>*0&q|<<+{C*(svu6m(9Y@W;M#p+XMk;vW*;5wl%|IQ#TQU@Wd{ z#5$i*m?1D_+}q8v{P@b ztJkK5pHo2JIEj=o~1PCOvdi#AekWhhI%c`W)%JS_1Yqe7=e`Z1Dn>! zU4H&}qan!d@*2t$+2xB)n3<|aR1=e<^YJZkkS&0;)hcG&R_h=># zlTxqhTKeK0>?Bh+WxsvqGJINEIfSM<#Cr;k2)3jbO*Fj9Wzyy*rYFl;8+A&x2XO4f z*rTJ{Ts^$aUk>V5;Wln!wUOr(0y1~zlJtMA5S}|6LPC+{LrXaS5Htxk2kzdJ^vEEX%rAs@aP_Eebgx7ga@JVubh<)dzZsueD)A~`TCOO zF0JMK-~~P}hBne%b(}=phv`hlBa3=;c*P@FwCPV}*jcj{Rpm;#&r;itPt;Wj>(Bog z2AR#*V+O%Wga#PK@m3FM?!;2Iyo94N@WqF?3l(zy(5&(?k{NE}0~h$759|13f+E$v zIF~vrbab=jL-QpCLogq#)GJz=CsVF`8y6;YqeB&{`VW!_(xulwTtPa7SKNV%?rrBY z=nmB@$6Z{?gqrUN4j&#}LD3Vv%b1}QKHypKD>w2xt+w*R^iDZLvOV2DGdeQ#FG#_~ z4HrH1>>wriU9IMmi9Ryoqt~j1#Q$4w@sk<)sW02()9#dK+Njd}Gs>YR6Y>$*0Ku{c zkG17>9@O4byD`bIFk7if|Nh{FT7o*SRT~CYjF+IG57{@$4@LX3ztvYfg+9F4V z7IU^=XsC0Y%{+1mR6Roc`NSj}8_jN|XcsOjgfQqf4vEoRJg46E?9jDJ?GAVBVkZX& zmtEfyNICBr7AGbX4=$d?w{m4ZJe*Lx)nTw;1MV}j(hB1>?9xxV;po)!{ua9Xc5>#C z-yB5?Qc)#Nda>`2H>WobIavMq`J9q7-0Y5 zqBhX~@}DhpM;>3{B#`>%_e(5ia>i0iqY8qDX}Tp#y0a(4zr)k}uphF@ZvfyA@6M1g z_uI9Qx;Qy^9$c7B2R7}WfsnVwpcKx)xua}Vey)SSq|{O%h?dSz$FmS#|KovbT(fst~tEMmSj49QJ?FmG*ehVyaol=?94gO&1L(_*qQo^*$z4QyK!gD9Bu5% zA+^5s;Jj!T&wR%`^J!VTd;V*6Lkl>J-CzH#5D8-V<6}u9>$3%a?WKl`brx3i&wWTO zx*3ertqN_3)&D^fXEm4>=rv&Kr%7AoW~x2ps%%|Pi=>4RQN-dKYFlJ?xD0EkF`M1h z){ZW#OgJs$rOCR2K z^;O-C(2kk|@#UX8zeFZ3(&@qcJJ zAv((3vA*u;&guV$wKtE)dTqOhziKz{W>Q4EQX#Y}5+Z3tbY;vOk$EUnk)cUDjfSX1 zWS)mIj}0h7<|#u(X3CiPx6TWD-}m#o@Avb(pZD#bJ;-%^hx0tn<5$s#Y^`ac? zEW+`I=?GZao{&^P$0e#KUkWiA181%f?PQZ&68{ex*4! zWvlo7SLqP#)-jf*#3%4FIk$!!BHUp*OaaLDD^e}>!7nQ;DLDn2sM4iNDMV$Ay7@h1 zSzKIPX;vNX_csc?z&9uoXF!y8#HWN*kD$>F+wlNweh;K>@J$(S!TX>29*gR3b=>Y1 zq-Oz?lz6-=4OyPNmDMX$e}N#T9z!CPn$!csD)vjw3=q{%lNxmh>@8yuJt2Z0x#NL* zzS%#M-l%j4g^1rXs2498G?AqrI-Hwx>()Zas0R|XJJALmSVpTJ7#XQXM%~id8mw37 z#iS2<^3qn4=GeqrGZ-0kjXnJoxqyDy`2WTP=OrelU}R(;^7BJ!ZKGpj5zu2mBFbI- z1GjVLUFl1^v2+=!!uR~SixQ(knazwPXTDlHms?_Ve7p-)Lczz6AG3d2@7cQ-)Uxn- zC4NYE_9tyT7`lV`D5t=2hKPU5hbSp^JPBwNDv5y%wEXD0DQjLnN~8$zWx2EW07Z=g z0Q=9$lCzv#)Rqe5)|{{f$sFs9?5uqmN;JGE4r)y7!F(-(r5_U6PvAJ^FbJ#=@9hLr1paFv2ShW3KyTIQr7k!t1_!%eV_T# za||jt{0y8s3}Mnpa0}v77KTTTPr<4pRB%{z5o;HMo49Cmf;fp}JY@xhe9&~&Zgrxl zsBU4)(@xJveAdX9%lw>exwn2*=-z}M;{B4(PXOueFvJs}Crmr|iaDslp9|~L+w<8o z9son+Hh>?|m2|sNOQ+VYb&#Zgb~@Zle&M_8#Y;B(h|~apBa4O zs5qt}?O$xgr(OAtMW|J$>%HoC&uNbG46#p<+o*sc{hagKA!A z$^SiKYM=gfz8%PQd%Aw>oU|AD;5+lWw(qyv?y0^0{zi6$VYck`q47wj}E5EsQ4j9<8%E{Z^{vxy4{g?B|Uo}T}G~s`#1wPGR;jZ z|IRf1qjuV?F8K5J$*J+a;R0aj9>_T-O?wymW^6QftzP+~at9Ko1n=~jJE5)5RlRQu zHxA-)iIv4On2Hii9-Qx~9vox3( zxD`*L{B_;S4`@F9+2a?-zfVO3?pBuhDXhp%k}t)|`Wf7l;&j#F&7V*CbG+Lg-05cS zR2i@=KI9UwlevTN3!iyAFVCR9)KLQb(?Zl`q$;;InSHuyseJP1)2YU>?kuy2-orDY z>VZ`7;vpD%c3@lc+{3BddPe=?)4ZZ_#qnaz-MQBtG6y?;#4PO}x_DBDY$w-f1$wA1 z2V|+ha+rZC2><2DKk3jIi#=1z7~A_TyyTdbA{ve}$z%@unfr?Qd((0*8%2_z2$3h=T5fWXY%UrlSS}KLo*EaLt}Ow)+nrL{FtSy{tl$MEOZ z+Sc3Gby87Nr5(na3=p5dQH7 zw;?Q%NDPP7P!w3l4VERM!Y~g7&J&DW;E<1?xe_$kSBWRg$WVp+sSi5wVSq;gM|i<` zG#DIcV_d~kcp_>#OVTZ!ZiAQufSf2?c-)U8Cj``|$TU*Jvm>}$lC@(MB!>+YJ< zfU+%ZHn&54QNFk2Kp?^6p~C%;{2W3iVrTRaZi?h-nKh*8ZK{vczHXM@vybRi>|lao zx-PNtavK}0YpaR&f}6Ba2#`|6=DH_$XKH%;TkS7KEGOy9H|%|P)<$!@Y1w74-!r{uvC&$ z;kl@=WrHQZEnNKQE?e7?dG}P%4|ua}2_+Z@I*y08X=TZ{tj||)0iW&;x}$pI$5T?v ze_q_3wy_Vg9VJ!Or|gcuv>hgX=3l(Hr#gjXRH1%$Qsh;Mjd$-QnE}ye#^*sP7RnJ0 zbZ*}R{ao9gx<}?(0H27VFRY6rq4ZVb%^wGl3ZvnWA3rR7+Dfp{;o0oz;o$*vGYl3D zk45mFK@C!_}?&uIPr z{rtcLoL--ZIl*p00En85-+|OWS1zHrZ=j#Gb7xY=zx1qLzBB(V~LMDi0X1zsT+>`p;~ zPu8xFZCV%Lbjbe#vGIt!14e^HLkRN)e$t)F%E}_R;ZlA_e3ePw1Ek6T|G2kr4-udA zJPHBOeOU&JFxC#1_vo!n*B|0-+GOgQF03|a)aRp+?QWUsD0;Z`^NDKeA#2S1&+P<} zc4E`j)rG}a>WQ{G>r*5YadC5VPrJCoY=Z1g`xWNQXzvh5x?t#nh_9b3oJ{bqa7;W! zE()#>uY}ux{Z|qh!lzF(Vp**H)`bKNs1>$!)a}WIX}O#u2i3eD#pm^Vx1yHOWY=Yu zjc-p}uDzB}NScr7RzIORjn_8~LX^=7p;i?>?749|k@~{6x)jmURrGvS!;dv<*P61Q zg3^!!#>Fl5)$meyq5s@L2@;8_YYUbocryYlFCr=kBCjL$5ZvNA(LKtVY_O>}CH^w7 zR!D@5!@93gw8`=x9N&MU@44J+_W87pOrtC)T&Gnt<(vBUP9J(hnu-f;6h`7$5YY4_ zk7g}&zV=l(=iYr|iZi?V&}scck~V%zhT#6yPJDskw;i>^?~(*MVtd#QTD3p)yV5U7 z)MC@W@(=Rcbl2!Sc#`;YAvX% zHLXdrsc+XoA0yeU-M&|ir-)D|=U#3M-P2SpS596&`aA=^v({v%#!d;QWB=YMAhs|a zzp$vcNg`erCdc~w{~;--s~c4G4H?WM4n*u^@l+y>AvOr~xI(l|1t4YcJG^BtYP|A^ z(Kg|P`R4}BOxF4=doFytGwU6VJ2c2$Uh3N}_zIaW5!ztsq!F7`G=+rht4h)T3q>Ff z$CntPNE{YYdHITX&Oy1gSELb+LJp&?-ed60#Je>`()tScmajP1ZB z8j*rsyLOG(jFYMnB0TH>A6WTAs=3#JG()C2DChZC5-vO!8#8^53mok9HMMS&No1qr zEzWyB#ISmZxT5_{r1Nl&hbN3Q>Dof~#BWr=Y1#xmP;A5m++ffnK1FO{5eKY+3J}#6 ztX`Di*N_Mkd*X;^RSXM*7sod_D8Q5Z6FyJ?{tiTGMBYC!4khYuq!l}F-&#(s7#aa3H?7pA;cv5LYQJ0sa1l}uEz<)=3{V_>vM|??7Q??c?VbIh)O$6Jh+}v4WBw$VP zgu}3ey-V;)>rvzwLH+iuQvNC2Wdc8Py{^wdJ1rurU5!c$c20cDlXrt?DJ*3LUBXl_ zZB2VNc0kiaW`PLG+62 z4`myz`y&sWCJv)VUj@l2kT)WeM*loQ>Jc}#3TcO<&o+p^^S1Hud`hY&xqd~LR-~RNpc7FM;&a0NKY;-j8k;H`z1WhOJSrm zi0D#dd5IEVqWUBMbmr*5ipXDkllN6S^t_QexS2e&R#)r=Bz6JVC!jAObiL1tpqCMkQYusBO*T8u>$=59*`7KCm_oDP?07`C}{nBHVBhcw6|@V^`~7 zIC~iXF^qEZtbXD2fpD(Et&axwMSM%+FRdw-+2{`P6Ud4m5JFI)!guflEXXb=NhALP z2*tR5y#e_$n^zb6xMaq@*hbE)sFA%V5F;&rtA&1{6A9Q_XtNVf7wKJ@p8DzDns$;H zSaBK+PI8Yq@NYaAeY|VnMhLxka0BU&br_N|#9|bH4pk z*`M8DQwUJMBlPrA+%1K%tp%y83i&-O3Jysq@IHZCLPik~=`KY2N^R?1HUkSG8`Y>W zSn#_Os!5c*(P@*%`Y}~SMCmq^X-4f zqRuYDR#`-G9^7(D^V7j~M>w}Ns5^Js%CpIIl)KRDaVni__#`1J5CQ>j|0v}1TxP&m zx73h!ozJw7d->jzKeJ?vJ06MD7tD6!sBtcv)yU-JiAkU96*0j7oDL$g`34gc_oSru zGW(mEuX#~DmKaFC)tp{ejT`+#L&QX<=iySypZ`C@J@2GU$?6?e?%q-X6S^oQwY6KB zvDxbX70S6$Ng8@bV~XkB8FX+OO~J9%X{s$`8R?0IXlQS4`o6_ge;?1Z7nkx;!8^}y zi=j7*jOOs0*+;qbQZ)m`^cm&1oo??!FnynXW$RQbx4$l2P?TFsrwdUjZUB{?mv*Dk zG--+^o9rHD|GMfW72_i|?jm#)fC>&e!|H*YP!1r}D`h^45)pRXEqkub(B~axTg#k%%{#C-?1y#_~Wx7C`r)1b5W>7<~`Tp6YtOg z%A@e-Kd8;q>aaP~ruh%R@c#3JEVmsx_ub@EY6^#ZhKhyWq^-d}?LB8!HoTy)KS+De zrg!1r?Y-mwwD+99(QmL>UEMveSS5C^@|u%Bk5x_%mBv|&+NI32f~XW16z86n%wIkD zzxxacKMJIn*0qN|0Jc-&F|(HP3EQo;Cilia-Ms*YEK(MO_Af&3-%x+PL|g-@Yq)y_ z_3#$KRrc!tgR7ty77>wrYo_9$j1ixKxk?H)Sb0qKO;4RduYb6k1-*g7+gx~EvBE5t zm5j0S!R}(gGrQ_{5{%&6QHSXaeh$MbH*ef9fbkWv9{K%kT-%?N#o;@2EM+PJfSt@w zRbVMrM+6rr^;^v#5D!Jt*9Cw}Kt}L1T*zmGRl_G3$`B{9aMVd;d4+7K0U?GLARmVweeEgr%^_ zCS!C6Rt}h8cq0qNosrVQf6r$ma@aRh#}%CXj`A-WOf6ll^i~t)Pf2z6UV_<{-OpTR zYY>u5YvK+;-tCF{hU5)!Sfvq%v4f5oz{r&3xWH~&4Y6Q$$+)8Zi(NinD;*9D)&3QO z9%m1|S48=rm5N$yoe9s7b^!CrK<^3Pg>b|$WCji_4=S=8?1@+%nOrIq_|jY1C{AVI zv;vM(=5B@ zHG)+93~FL26j|K&JON^!I4M(oZosMhemS$uPpZ#&L0FyQ47|*_N$S;NlIai!aQu#3 zw&UchOFwmB+)fNkh;IrCi9j6U#_)TFS?3e-pD@o$ zk2chS`&X;XpV`{tUs7w*--3Wx^!%U}xkuEof|AweC5Y1ufEZ#;*!pcCTToWT=Lu|S zdWaJ)9<_cSF8Z@1fG3Y-TVD;mTOgm25b7hwR;q{b9^KhmjRWzEK7)9OYt-Zc~ zV4(c3)m!N|Z&JIb#BwL7i^yBfGsmC|djvcwblrb}u_#7gEGXoCo*eM*D73d(b*;f5!6WP6QgXL_qt+{*Muwq1ODvcxnLD*b*F+u6kmcNmh16a)``Jem2b z+csLvB9C$8qGhSmts7n6S!RQ}6E;4XBhXnIcnhA95B1X#RtbDiYkz`4{OBV$LxZ

Jij~v~Z#TpC+`eSS zFzS)V(K7q_!kr-O_Bv=1N{@m9s0BC@RMmzwyqU8PIr{jv_g_(wO28i(E`U+Ui- z*Ok`-54B48dxVmp2lr!Tq)Qd>e;OSf<=VGT33zA~ASu1|BAjAR)cE)`#%mzjhgCizX0CBmd?LtbL*Hkd8hA)_0Dk5C;l zF^jNr@1B>>%R6>1vFSiZ02R)fbmjnbcZV=80u!=^JV~I&z>H4BJ|rxa#1fGJ&FJF} zv#+Bz^6eIA5St7NstpE357?7*x8{ukIhc*Y$SEx-ZISby{K6u!UE!L&G5>sn*~t}@ zU;>O6y2~GACu@u`;5TPj^EHyw->rMfN+^*bvVJH%$=i$pg!W4yXjEWJt?w!%o>xbnmdR|4&3wl5@ftz(X!I}z`fvpBWus5cbz1$r zPMJ>%guZn&O;07Gy*@|=1ECbN2m__b&OG`um!GBB(O~pO?0T!#+1ZOPB8q-pJ925w zG@FLC+^`%&5cxRu88_t4$jTZCYOH=f3BUjZ=p!1)e)$vU*3hGRrswWvuG8aG!*ioQzsiUe@}pBLK>}ReY*C! zEm9wX*{u}&GjSZ?7I6991XHJP33aKHrD~Bjk98C?8-}{0c{h3x;sxb-2)lwLE>b5? zR$wh{lO*<)sP&0g7|EqTbEi*|S-$>HalCA7zXWpMtIQ`e4tsli^l>|t5&dZ5z#SWJ zG^!CP{z5=9M+!w+yE^C2ZIi^>lQRf6wV5V7jbpo&#$R*yTsNT=-dOgtJF9#%}5*GXRE0$Biy+#)!NImR| z#IrNEUR8*bUej`>BVrZQ8GMF^6ypfE<(~o$_Xsgk`#(=dH$*S{`5~m{h9)l)?Zewc zgnvzm*XmV^2Eh1_SP^0RUt809svw>>zoaAGH~(aOaX_VYHNY3T_=WN`#Y3B&#m2{H z<9l`ir6&bZj;4yXrb`CHpQI+}X$*I-qYCc&hVJV0iM}Nhzb|nfe~NY3bvh-+xu2{f zaAVJq>5kx2i%m97Pd-O{HYn+3$O#}4mi`BBdt&ZDRb?pkeBVW#4Rph@mJcyLf;U|> zeU?1Lj2eQ*!ksY+u3i|DWBM{PSEm#RN%xa^l;`79`+Cbu?_3C}3H(x|kR0(~?wjA@ zo5}3HT!MI@zf)sUvM`G(xHE~8X+UPt$@q{%kA}p@|6DW^ZCd}q`(XJNq<`M1v;TzN z`f0Vlpw=zjX`AGB$3sI|fPN(2)T`ctDVPo2wQQ8<5-k&ZQ zA+xs(p(D4!D+u%QvY4@DCLm&jRNm&GMyYJGUc-H+XmJ&xL? znzN>+?icx!Xk#c&Cvs|%0#?BuEeW7!^?MgkaXy#pVSLnaC!_`dO-06e#km{s#%W1> zfC19n?b_E@NJ!2%ef=}Qo3zlm2qsQQr+A6;(mPZc69g(h{(i4 zwQLu0+;a98n64l0y|0Xc0PqJsf|zdm?gw?&0}7tW(1VQL&(6D^$_h-pazLHU6Czhg zvVLE|B_RoX(gX7Yd&0gjc8qe>&aD>tYT_yx~kiM=Ft9Xabk5R@Jzz72Fj+B zEQaYI#Eh9palmXM_!KPIqLG<0ro|!$>A-9x1AxY8^B~pp4g;p6f>|$`C~=84BeHb4%EmO1clC)z|Sv-;WF* zArzXH`VvsO2j48)MI1TN4JC$H2@gm~mfT?DL+U6pRAlM$<%#3JsEyx1Qc}%o3*xHU z8s!3jN-Xk)yYCn>_T~(4daQNA)thp&OOS$b5uVe?5U7 zlU?&;_6C60)KjZ*)#N|8*uhdT{T1pI)I!@!vOYH~OX?2617*<06D z)Y8~eap0%b>~7sHgbH;y0W&HkOb!#p3(9Bm;39)_%?4&8ZrCe z`h5R|ElpVy6BDquJ$Y~aZV8a0$ZRlBdR`#j66Y{D4ORj9& z{CdbHfV*@d*eZk*B`VLEM69YY*o-zkH93TqD)a{mD3~hIyd%4nbj2aAIehzG_P?0p zZQEyK7{2%cLu-9+n6&C(o1uJ^UQq-~dGyU^Qrdej#xRgRB^5K@`QMvuSp4tv^N*vW zQ}j{;hc~$mXwTN!e~d~mo$=Cp`0x)&bPGNt$HM;hC^?c)Erg@gXh)aF3vh*5*64Wx z=K_DPK)r+3h3lFPo5Vjz7rw86e385Sm)1wPCGBde4N;hZT`dAheFfAJCAcPVZ48aY z1cu^lZ85szUh428rdVRqQ1>8)`3uoJnyA>1(G;MjxyP8wX|DM;pY!tbquu3S zmMP1`U(k3Op{b}f8$kt00t=?4?)qa>%co%hF4z5^J9O{=>Cj2*UZJ*ieHuOEur9ud z(O98DYh`2Ers%QT`d@jSzB>Uk?$yq5f&LJD&Qc9Y(xZy;fL{#H+jWzjfs_ z(>;`6Urqm}9gvaAdmrbXOtZcOO`~Q<{jr#08wW0W5m0ls+E$ZH`vXA&qfKufs3&*{ zB57JrQoXBa$iLh8o&A5d@e9cs8|mBCQP(hj(M&AMsza{m$=>6#va%RNB_5BRI+j&y zcf!6{b6+ldXUz$jYuq9;X8efo^u*W1myQKIBELDKw0Kwi3rqjfQ6qd*z;QFZLF!ab!MePDcOd8-_5D*Xs?8Iam zGcqM+0e`fBRs9_iTuB87StIF*2w(XXM?M)@LL}6w1cux$`|OO)xA}3|gw}A-OkT?F zyv@fg9j#fWExmGsYIjIrrk3XGFrMG>a~@WY_x4kMQN)QdJUqPi0(`RxWhXYH-<512 z`;}?|h#e>0Um^ToQT5Wm({>MXcBU5TUt|3+(Q|xo?3B}E9a01iWU6`D@%_(;_O)UM zQm0+gh)0Qd%P{OEU`3kUPy|>WOz3}uz2d%p4HX4SU*~jH&%RuvUza-|zMjpCQZ76? zv6}l1+LmT%w*~Dfsps13pC(;QrxQ8D9qQf}Mjz1{Bez8Ds{zFU9=OD<6aQi2fX(2@vM@P}~X`0E! z2(IC>p6J%DFW+a*WC~wMI-Yp63Xr_^n!Q)XB#EJ5N^*6#0wHt`U+MFLL0j+aSEvx+rZVCYM*K!19Wb9ZE^08f3NsXx-)KeGtJrFg@K9RRVrQ8hxMr z*T|I(*17hCoB^sQgQW5Abb(SY$@TLkQxkv(RAIA{5o5%+8-ry%!SX?MND>7MUyuTX zfS-F43Oji+d@ zcaLGhXvAOV&ZEDS)Dx;R*psw1m$AIbE1KQMqhahCdNR z6CZeR=UOpE)yFJN5~r^MmUxdSk25O`K82c7yUQ2{z;LP%Gv&NUuT(ai8`+}Z1(x*p z^q-fmTsdI-em+fi(`1q|B8ZR9w{!vd=z@$GvXI49&Q;(0u*?V#p0KGg;BNWx-&aB- z$LgS>WhP|cO1Nab#;kLT=`vUR6Pgg6ybrs1LN+1XR~PL#^@Z>^(9<>QeBf5_%lNsF zB-|p}Ng(Kc*F~@e%Jb+p)f>nk%8}7Uk~I$xQFkj-%Zng~lQ00}pXb9D1_j#7eVabL z9*7UAL4-SfUHstek@o&U!yN$e282O~sX-*WZfR-hPX9{cE=Fv80v1ggynO)-pQ-!r z)GK$ejy%!PzUs)*5FbDLd>L;vxYAp&*m(hj5k#Pgo1uiuh>U|U;aXrZOK?-@8d6>p z3K?XpfRbx>Lv#X-UrT5AL*8C;JJbS&s4W$sPZFDElFJZt^dRHC0SKo2XLBe~jNwVc zjNaCkmLSZR&*xh*Kji|k3T!y-eKLNLxnbD(*NU?xgTKG4;th1x9yK>)zp=)J~x+D}Eb8qD4>8LFVI3}^Qxnb%~w4q?@#b)-; zb2b6s@`$P4^O={nYxYHbR-i~CoIAs9>W?RYF=(-B1IHQQ&IR6u3>A_ zYO3Yy=gK3Q$u{0x>tNxBMs~~WW;u$FAv=|{jL3_C>-!jrP%`Wug$togLb)psL2(5L zH6p>wnvdzMcL|vsy!BPSOy=hY2i3y*hRQ=K1RbGV-ASjvmMN9c;d8(dZVJQf0kHNn_oc-AiV! zEcl2kt-1@qe-?&(_+oY};ktP9X>drAYef1dZFh-ktx!AI_o+2VLEhx6X>?0l$K0#U zJ^l4)7!M4p?b-uwy@I7T*>~u@5kGP%WBlTKwDn(IY~qrSQVx9DS-Jn+YE(HSYoohooN4>inbs9jH~YBgTxriG&;i=%-MqZPuX;M!w(Ww-qtWU7Cc*J7 zG0mu`zl5w z)_0V@scNsVrYEXD7D{kZ_iJBG|FwN0p;j(E$ZUB-?e41CuWkI=c25#y9D-NEs-0Uz zL@n>J&I8L-br!KJDT1odMEG)ZR=n@nsbEp(v}xjsI!lA#X-Bb^Z{(tBwKeSY4h2G^ zf~@TQ(Bb@51+5o9tUtVXf%?1cy!9_+#vhz9a_ac3ejzbZq>PR~<)oV|4?tcPmJM2t zdk{3(?Pb!t6tQAj-fAqTWIb%!A2&`KUJtI~&xq*lE&DuosppLsBV3t(Xua4$#uO1q z@)?Wxw$lc8j}@N@$hanPswjWky-trX>Y&yl5pv8>mpr1AIf*?r4)5*G^DK@{dq0b) zOkbg5WNY=8V%NgNjh*x@WwE)2MdGH53n6pb$S(0;Q;S#Bs2W!#XM9FfVWs-IKH72e zdCtpnt5!uIbdXqz_F~Vob>iF3-Yv~OzrE{S_&jL`-RDOZjN7+f7pA|r*1w`RS0L(x z{u_wV+BHJF&Qi@ji5;r$HaT+XT$@$dUKoi~tR|n_ERirYd>u{&q*ZMj~^=km-0 zw)a-FwWu0R|B_d$T)f08Y%Be1==AUX1VL$kceH-dQEC4?^cgHt|M+!E*T%sbz!NKd+wa02_u4%cLy0Mdfa5r<* z9}YuWWPrjG7Vd6AP3&U57gKrY&z>J6CW%W%Cn~@OUdp0|b)ed4I)1HJ-8bb|NhDWQVq>Z zwASd#E9R$fHBM%%Y%*sMjL;5XXqD(SlgWrcu9<{luxiLWI1FJs+t$Ibzx!SIm3w|} zsM1HrjV|>anc2%ZM~kag8KLaPz?)BSfXB?mi{CxQeuV)dgJ_I={l?+q2VIE+?=MLA zeCg$7QW33~eT=O)@t*AL(vTUVM^|LYu7)uPYzG}Dp5-%yt*`18iaO;zZ4{npN|R1b zd~jjr&4^#tF<#6B^(4j?4G0gTn|_oRyU8--6scWpD0i;S2JV+E{ghr>X4 z)+Ai-wWkhT$UBFeV(`=9m1Iu?si6P%kFKjc4k?#Bf~%rrtsATjpUwUDR{S=B>~Ru5 z^6N{Jw(XVlNKNcakIi^L&|ua-<)I#y86z-M?@j7FS!Dr;?WMnttVQDT~kh z*odTOc5L+YW%>qHuyC}IT?gQoK%7p~Q^5NCFblD7Llo9t3)Hl=@1VPBp~hptI?khnVl{|hI?RG zP($cK3x?oLVj;;+>gM4B;Y=a+?yCD|cCr6~mqW5)SZN<1^nx&k_`XvqJcGfzx~<=t?Tq*&1IIYHKb2Ic=8>&NwR%O{(tA>E4kpOD?1sAyvEJ1G z@K83D@0aZgJ=<6%cFc6^Teg|dM{BHQkpTT)1%u9}g&qLq^dY$&YL(i{b`*0?W7Nh> zF{kKYrEJx|*gG5N_AJv?EH+d@Lx5=w(bCsHupNmEkxF7D1t`U~fw zSCibPT9MP68IpQ?>4%}ho+K9`(#p5yC&)vnmnba);-i1dOFYr8f zC#k1A*|?lH(M+Z*Vr1L$xfOHzk?;-jeri+~Dd^-|p6aN5Qe}Ecao!~VM5j%Z16hRP zxAAwrGJF$n9sYjJy3U?IU(P;C-?%cISSVt@R6hKc}X&} zCo-*cwFE=l{>u^Kf#*5Yj1DYzQ&?_5YOyt1=!B6ufOLzAkrgZRg2#eClC9E;%tpCvW|! zNq{kEp2RzjwlZVXml_%1SOJhZ{nSeT-}iMFRRjgA@gyn@8F$_vj-NgF!gvt#!PFvy zdT|7u#MqZl^H(15hbk!?-E{H}2~`Xwr5YLH0Z8K!8iyb{lcQWJbm>xoi+M|#o~>ga zA5GzpIQ8O+L(L2f%}IhKs=(8jB~nfZny17@kDzz$^%EQkBg_=0uSn!jr$}#or3 zZsM(t;$4)$lZcSQx&6h{x%*wK_I!zee(^?vt4#d*d2}A9Ald^CsHm4NSyBp8Z1(|C zXE{Xeh!JMfR~0&2BMeWpeV@BY6Wk=Bc}er79#M>VHp!5OIU#8+?$SWQ*hTPo7?&=+ zcT83*0B45Z(+AQdNu0{M;!N+y<*+ABxop_DF$%M`o>#}Cu28bDh(ln!L8hO0$u8(Lb>&}<7%2ng^GSKUBzi`iK00E3Z*CLm zDTb^v!V@F?v$hg!W3S!EFs3b#s3RsOt`zGfb?DF`GP%g*_YE&FNap#-Mnoph!dBt7 zYFIDv#Ihc^^h4!=(@2A@X}W9kPCTEN?mq>M*hRvluo2APKZNp|D1U-HG#jWIY#|?+0i_s{2lY?v5u2axd zcyJJ{fFC90?y`a{fYf_zRotfQm8O=@hLHffoN;)KPvP&#%j?{bkmyPlq&P{W&}rWfp7&YONV47WC3 z6O;7i5Yt$n;?XxOIl1=*sJInW6-E15jV*8$aJTeO-5&G&y?LHvkr<906{>7wlf=Qr z!4Xt*VdU^9*!C$!jCx~|FFIDlrwt8i^Z0xY4i36QBY;CvQz0UBKy6AyLPO|-x8L63 z$61-mkgnV2SH^nKJig-1yvace?=!QjQ{ANC$_@p{RR$3@OYUnHY#Y&_LGDBo7yqf| zMW!$oKKJ1_d;ym*E$mKZ1Z;jxI9|>stN<+KK-(xrIBA(QINVY|YeKSampAuY0R?{Z zk`q+{8@u*pgid+-ryN_pv9tYa-oM`Rh3Z3web?Mg76+ia0*4I6DnqHFNyAG*LPEJP zA8}N7hg?+YYwEzjz)&u;)(P&0(671q`H>_aMeoD$*Ox`h*$0-? zhv$Iwkexrz#r}Q)yYEucyc{{F-o}y`ba7_aeO z=DDoykRFo+N! z6X5WIj^-f*@zCdC-Tq!wTwDb;+Sc&F%?f5_F?5A7S?A&3DJkQjnBN_ebRvVjYW`Fg zOul0?lON^eM&Y`!lL!jDO_I#3Zrea}CoD+PkOu?}4%T2rKE@uSB?P0}=VIS75FAef zmq6#KrB+V)8A0tlPKlkyR|Y$RRnlPB6Tyutfa?au!d%zOJqT zKyZBrSr7Z2jj8M>vdo*rh}h1q#qBtuc|(wEfW#6tK^1tUkV}UXnbHBLMa&h8?(8!f z4>;#l=g{!AGDYdbrcZR_#F>qI!k6aG$jBH3w-aOV56RgLjP4QUt&;{_Yr!Me`+T*- zfoMz;WesSl)zGG0tUCAY>sKN-C3TK!x=l0z`-nUpI$BH|4U&_SV`OGl_9~n&rCUC= z_34ko^U^);X#!UI`Fop6>0NQ>=gy>rxaX80;q}e zWeL1q2*OiksX`3I35m+NaDyhk&F#NnK)={FGt!nYvIM&ARpN|o9$C5U#|3)%lT%Xt zx|sYv^=mo#D-&85Bb54tb;BDo$t`_vO*Qs>;(_emLGfR3 z7pnN)s-re*Dzu`?H+To&(M~#6w&2C7lcoVs3o>K??KSU;}82PePS6N51ij>otk<_*$ zC?PC&$hI1zP?hF=-=zH3en-M``a-5@uaaPB{VZ}Yrn-1*K5yc!^kfBHSSigQG5#&M zql+dzes9ODMq|BH>lY?ksg19n`>7TADfg~y@vWh;w1CmfbL_W zO?k?HbE*xhxEttI*OaFDVtZnQ@JxA41s}&x4`;K#FS~U|P{pXIQ=uBU=PZOq>5<>v zJ9?82eDt5Hztz@t(a`GokI2xqaE(n9dnrl2&v)C`0T5hDnfP+Q&_r=iK|2_GJ|Iq*x6kYoH^Qg z>qSL&st8>k^%{HMczfxRAC@yQo$5wCx4%2xTI7{Mz2pC^qH`Q?(44QYU^^JD{5>YV zVoStthV8ArCrI z&8tJyj2fO+DfUUT%)P@#yhC;I|5y18*jKtyx+mT*S~1kei;qWOISYFGTTf*nvd*(;LxltqCIdG2RTY7_*v*1T> z3G1t^d+X)w{p~YYs={U~T1x#4ZpquU#wT1-I{A!IbZ2!1ThF<lE^iVIn!J|TA_v*^ol||{}qBY&0UQhtc zZxEha856SE*Jo~N6k};*`1^k{Xc=TG*frt6>W2~xpFNeuBSJnpC-+v93&XBoCb_Yy z>1}brHr2rwTw|5XXVC+Tb(uxL+KeV72m zZPj!Hkfm@1o!LkgW5t{B>lNAXo;qj{-cS}3UMFyO78Rok7GYIxAAj)OpgV(6I$E=G zm1tSu+>)g^;+FRPnb)3L8|mrmV;;V;uR%88>eO7c9mU$JIs4tkckDu!rgXGZBqx#{ z>)AM7z%0S(e-}t|XKHYSE{O7ea$=^j^mVs~SDlA)lH#RDf6PUu3G{HJv<>SV9Wu8U$w|A$1RAUEyBM;WyaK8~2)}xR zo_2*;vud7je(s{>%jsR%Is?SR@&M6c$vw3c+4Kh+lPlPJN~MM(Zq6c0R4a+K^CE5a z+W$j)h_)Md)W#*4O|X)}XG+Tk2TI?I-R{?Bof)#MkKO+)bmcnJis8fgJc;A!xjQ4y zKc2hc3*?uz|Ml4Ck2@DN_tSXtQ;P_H8i}@7M;*5MJ%1n_IKI80R; zD8EL9Ylt^K?8p&BEEfKkwvDqZ;sj@2@7DaGs%wc}s&6|3?3xoU&AksDtj(#M^?^=B zC#!DDFF;TfyS;RHWG*Oj?~~wcUYVuh6SK{55l^DenirLuBa-HB{3q5ysL*2viP)*Q zJh@1=by3IdU(KZv+*sJn(|hv_K6Wb#+w^+|ivjssx~zY;@}y>d?(@>n`@4DVx$vRe z=0SxzopYzmqeYt;|KQq-Lhz%CX*jwh6*3Q@uj-dA8E{1K+p` z_Usqv2rrm>7mx5R_*4taR~q#H->nVtZO0bhyi;~tew0OYYpvX};gQO}8gI`=uAF6y z11wJ4wdD`hT_&AP+K87C?eDfk*k}B-R0VNIOk*lCxGL!oZ$p(>-fVaYn;^gYp>WqH z(qlAmj9u_+%s`>($Y|V5szAkVzVYU0>)Tu(&!-nEMp> z13=KAZpdTKEAa%{eAvYRdiXZceDK;@FQyicl6rSbEH5YLCo~{pfX)%m8HmwQmpfym zw?UiUi{Fzr?tk$!+a#glX=Qj!I^R9Z*5Ug6dffHJd`Lrl^K}Pa?5R zg7HHTA;G{t;1;$85&I?>~(s@8If23?0R)QH@* z<%0+D_aTBvl>R&e)-MM8mn{=8vP<~4{6py}&pNUnAmEp~K_lo2fozI>MU3I46(JYw z%3%nNm>n#P1a#l80_Cdt*#HK#C!WY@8Gb%JaU5ulBGP;}abrmS$Huy1nV`ubL^r5Z z+&Bn`kk%36W0lWlYXEqpe;|_4-1g>RM}k(OqgF`pa9eRv#S63Le?NA)FXjv%3Jc3q zE_~yVB>ci1m{be{g9xDU#E}ME_d=_GV3c(mSTUWKqXP)k@NYDCt8e!NBT(LwE-FSo z`0VlB6zok>{!v4mKS4^5dnJv&%;3b7`W5agi{UW71rS3x>4Kd&!MVo~@SUVPAQz=h zOOvK74G@TxFarrDCp29XTt6CZNx{qyFB^Rm1dt{}xdB@znWpGb;B%uAiO>*<*&AIn)xwuHV11bA= z=p)xG`Q`Wo%C6ihd3kv<`Wa;xqlo5eT|GUGILi&^_^(u{wzjmYbw1dh%h}exx>Y?+ z*11XJJuJtM|4Y?{URz6-T3UX+!Y`44j3Ni-3YhSa(V7;5`?Z>;)3ge0*^(;k1S2Eg zEfknbTkfdc%wAhSM=Nrh>Mu_Gx*$oWO5AQME-9%7JiR&mSM;UE{nHSf-HA@`d$@FE z?5~nN{lfwqg&pOlRh}A!A5~aV7AMB0DDc-GtBWIl?^@5ftcuDJrq0Y4&(&BIoViRY z;x}6fP)JX>!}>l(C%<7REG%qIiRixp7=;TC9+YZB=hnWzePECiXp6XF?b|mTCHgziCwDS*yA!wLB?MzuHMq{V199f99Qrbm zIrbmhI#f>dn}d(l-T(JM1vLD98``YQb4AMMe>f$GAVNlv5CdEyfY=$6A_1Bq+1l5x zUmpUxg1KpO8I`S7FHFY=d+MyDqt(jOIvt>YnetCOqCkMw7 z5{1yACTcbi)3}9&(~b=EG^8s7DDi-anptYAmnC~2`0eOv$PsoE3FaiYVqsL0W*XB@ zCBhcr?Fs&JU}K4Bu)e-(@)yM=gXKRW>w)JB&1R+heGAv0duBGTZ5^57p^DKdZ+$&A zj~#R27(eLJKUy0(NK$`OP$8`3B|f}4o>QNZ2xQ{gzo3*V{kNqq>|o^6+-b|svqs$d zFV|1^Onv6wEa!^NjNDNkp-cHM+9d?9YiWyg%UoMK)435BI=uMi(yoT*f(FOWCj}3S zk6aRnI-NVa(sG`xk_eEfixZUh@G^_9P_2v&7pcpiJHq#~N^mOd|GO;za%mI0_cg7f&Y}_ zQC58xI_rbdlfXxJwDwiEzrTGz+O$S-NyNI1&9e{8p9jynm*t801zXM4RZCk7&?6Ck>`*3W-(QflG<8zp z!0ifpnKZ@5RtKpVNtRk$$rK)&rIAo3SU(px@7SdksJ~s!XfuN&i=*`mbw2OemK2qa z2d%yy(Ce^n>lvd5!v?{2l@fn{@9kAf@2h3*Jp4Sm#wdK-zTc~hdS1*G^EH5j?iUp8 z`V_X|!=BJeKo9&0=sC`|9O*=|<7DL|Z4y2o-q6IWd0OD@+y`HpznO94OT+2a4uiAs z;J?%fl=ZDv8(ovPBm~m~1Z7|Dcx2O6MOrOZ~NKX_Cq-Q?1fl54}I397()X zFXl5F2wNB6()J2$EAqOWm1H^}h5yFMFP}y}zLLoxb@nT>^j8!^7LWyaV*BbMC*6%Ser}KBx9Q zwK3!}{eOy+GP<7QXtm;`(Pq7;fro}@*0*Df<>_=b_+I^Sa9=6j*;L~GAw$^W7 z`8er+fqU!0P*(BQwJm*-$}wF*c#f4$$`j-r>K{nLbu{>`53Cn?w~9<90hjPLK5i#2 zg{5Y<68K4_UvaQlOFb>yo8W25t)>0?kZTlsM3q1k&z+f{QWL+Bs$$52x6PFBkC72&*xiO};x|&xV(Qv9NoWCWQ`Y8eh+1T~^I&hviI90dch?7CuGNAC3&I^I#|VeZ22Q{) zE@Ofjk9VPUUMzYzmphg^C&U{sPa!-zbEk11(80)rBIuYRVR1@qvP^n6#@%8wKKu3E zZT-m&%~OOh#vJ<_r>Kb2x+_qy620q|En8YJAqjk$j)=&ZGSp=~H;WnUYf=}I(uzoc zv62Ys9Zq4ytPWJV&&z!0>v=zW7Db2|f}bs}qh~5QZN?S-@#9I+{`XtqbuR}oE(}HK z=AMM8F3|I*_wKy?4Zm^wp@58la!55n_cjW^P|!oTczI>ev6qLb&TI#M;AtfCL8()x zs_7-9)0dmfq}UXxUsUk1Sl|yy&^jxKQ3h(>Ff`JM4IIj@S2jJWpsx}KJIxFh|CVmm z`LGmzDF&5?PJTi{0>5>qf~1_0ky27@-Civ9v z;n`p(9CPnz%pAX5f`+NyugJe#^NI!HE_^$-mXT2xFYY6{?mXwPO3uR|*d}p#4SsB~ zy)R?nzMC%kjs*vpo0r=}UxyOqF&4(RzCQDHJZH{sP(%K@{@D!dBTR+_0bg)srq^L-MD2m;5K~G>$$Aj-iSC${-6+t)>Mc0CoijUwK>Pxcy zeBqXxu35^Lg_fNaT&E^skF>}ABi@@yXPwH#_XleV@ljBN>~;SL_lrAsL2s)&u?Kdz z`(sCuI+WIAE+j+zIG(`Aw7!A6naL*h>R$8^@3`s0HnL$)(fmu-eY?)0xWU7Kz9#+~ zEL;TLKa$1wJ&Pe7E0wrzZT&FOamtB%j|~`6YuBwi51r#lq`JJB^Bd*rRvDqkn3tcQ z|J!fBJ%$=mB@gof4Dkf-BGS}mJ7sG`3%lq`1iS;&F(@jE2mZSXOO`A-nrEKYv0btZ z`T;&GV~{M`SZc67&X@Q_x3?QW4U<2t$zYzc>fj=bU5kOVY$qmfpq0+~e~LS^s4A~) z3)_^ET$5N8D=MH+B?o&^K7 zt!{`C3C4<_cMW-%+8;o`PzgP7w8=93W5nVLCQigGKi`KD{qy66ewmbHN1u5KJn->u z5FbTp$l;spmiy7(8J7;m7Jq!SiZ9!*h-cg1JdoZ%31t}31`~WSEJx|A>+2^9xfD=m zwa4J*KIWx4_=qlDzWfY!Tgbj#gek+eS!!o@6#l@oBP%y9hGx;|p_lYSivK53O;}|d zRjAhC9lZqE~X+5sLH5Kit(s?01-@G22k9p?kZA<9g_Rm}D z*OfcGdo|1&W%n>j?TZF;PI&hQ2WK+18`$9fT^Sf0EUk3L!%e)G0hQ}yt@@zFo3L(L zzCw06_4LzbJ)Zc-k98GFB$12RD?*JaXcs?l3ej@}j@9e5lMc(H77d^C$VU@>DyDDq zNb}M!SZ1LaKf*6de7G4ca$Lp|{&snx;3D?+D)C@Nrr2rJYqWuR>8{TA+h-LgL#V*o zZ&%>cIq%BPW_)9=w_rT%e-(;u`^N2DFA5QOYRp`qteL?xp}&o!bRH3=80Nh&RdAj(EvM?tHZm+0aprL?+bFPrU!jaT>4? z2vbh9{?)Y^M5we^`QyAEt_~f;q>YCUA7YDrPZ(A(y`KK4yK!@xWsXPF)#^_yt*nHR zCIjY#35b>0Iv~3=X#TaDw(|aZ`uhHU{l?B%X1w4$azxe1sHi9~kR55)uU*UE8k7C5 zcH&v5^ef_Mf9cXClRN#&@`}u!#GK=6(}2!`g5SeMIHEw@owPKw&*LX|V*3im4eFvI z`LltwkpvKQyB!!dC16tb3^6&jnHR)Q9T0TgtC~FViH3lYfQa@t@7|wlKkRyOOiJ#M zkPwlv37r%Qv#7v@lU{vrwo!edI=Ii+v;nQ=+k8I%V9%!>%-~V1gT5ihlG|bRjD!4w z0yS0%s<0`%u1_OdWV5~~7*ZV*MpQ!_I--D_a&+>AKF+`M#IV2aVO;jt;lnR*SPx|> za|#+Awd5Tr>cJ`}ZaKY|aHTOiBOLFNu0Go?cDjFoJWu@d{h$ep{AerJdIJKh67E(D z?>R7N%I{9M>vlX-nOx9U&-%cp9>sC{73<(g$Ya7)H)=_(P#1jv;)j0uedSab?0*zc z`u-^-P?i7tfO)%K|Nl4DB>{!Fx-T5m{HA(vPIR4Mju+v859~gJb6aNrrmA_^zHG49#`xGrzjC; zA3;a>$HZdCUD0fC>{Ay*fV>1}WW(c()Z15g=*F0z?~~tmQ`fNDMMZCR4yCHdH|OvZ zh77yj^s1{G?~>(s|AqIhKDEE>>4^Rtn^GukJ9lQteK>iDowE}tuaJ_3+_p+LQ-Alwi3b z+ym-1=#X=WC+wLXs{vyxWQtP<##yGiYIK_M-y>G64!$RhSo|A&b7i;=9?OV?>LHG| zzH;cTQFC*rjad>mn&?9Cx5w7M&8Sz>577-zAQ}Ix=xWP;|A6o6uH6ZI{`#SNEaQdd zk*#HhCiEXYJw0IXdB*Ru=a2k`$AcHpay~5|_~g!MpFX)Zy;$cADl_q%udS`EM@fxY%`HM8HxGu?Il%c`3qyCDWMZZmAaAps`Es`a zVSM|sgctds+|iQskyVLfI^chXU1!TRP(^^I5mTCU1~~ixs0V zhlJWq^lqydZnmayD3r>^e>mxGp4IqnMKc_q^^8)%_Ip2f{3U!rx@cL@8kfnU*x0KF zc&*6siJaw^CYgge7C$J3l81R!^5a1L1enDOW_0RoXAR}@k*)wVA|aypq;70A$v3$c zGIyE{>F+(~*iu}$CqvB@nKB4AmOW?TLYr|MM34dXOk-x7=$ttjG|}P0*A5cY+4`5g z8kPxBTYqQs53(Upw@!5V*lx?V3%!*>SrYGOKt6m*dbH1oB<1pDk!|>BEmenmopRIY z{N&%+D-hUJbP`aUfNGP)zb4$A*<#fVgglYi!!K=VyhcY-aYLe(>U7&_P@u6qcGzIw{%; zJqzEb(cX{mQLuhzp&_#ppLAx0JoptU&6)F#xkIbB;%$9&i!6%hT+x%bR z=%MiNF-n+>K|M}+yLXpuDZ+4j9X>GP`HG%!Z9x=HOKU2m;LkWtHc7HvDpvSxuT1un zSgOe_%5j!DNtpRUX{BiN;JhzN33mOC(Xka;cBgfzIz<8|vN17{096n-BH@ekmA;tg zM{oQK-~+S?>fgX2&fC8b&prx2;c8mEt@&Au!)wD(^U1dO4eOnRtd8}2XIQST1KUPo zF)!t@?>R+1Q7ZO263@*qVGb~0$$VR5n6qE2gYv#)eFO$o+$`xq{mQh~!?PeZDGnO3E%OF50FR7WFPi=X2} ze~;b~R%1>Kc@u4W^RTZ-v}Fc*+OxCgWWSJaO6GI-=oju8y24DC2~XsGDB6xwZ1HDe z*}Hb_Wa*M-5Jk1{?Q$Ukql~xrvPC~(FJ{#{p)U%Eu>*TyS5i6JS@jpXT zAY4yqzjt~1Aiw>c4{a!nOIRr{SnXr*KB1Dpui@9EZ?A79*Ku`OIT*k6c*Q!LH!X8z z0N;m$MoPLSQ+non`K1GrIq{-}Id8XY*|Lz*tByuuqNsX}_Ge{`6SOB z`TF|x>km%b>eE6jcszhq?5h6iYe}L^DNf^^r=2li#LShbs z(X6ecs$)Or<>!xmHKVOPa#PkiDcv}UN^9%MbfU~u*6gXbG+U(Hm9*GK{&saejvdz# z&1K<>1{Ip)FJX{Z0hKilk-U{^9hSW7$mw3Tw3w=_DYtE?`$I~tEG@N_Yd9_=Cr!G$ zd-v`zMu9{2cYD0O`eV0XV`dsmg*E!-wH>ZzjQm=;gq1MJ`;(Sdq_*jr<#l)R^G)!z zw^GrEH0EFAXN@#5xzA_RQLf>!VyYpIPk6D87;nwu+GP5+1{c(tsI1qlS(E%iqmLJj zXw`jlBQb`;)h(EQPsplgzh3a5EghJV0D;#^^|my-do`8cCa^6VBC-v@!Rfso38`^q z*@N7xW4xL8c=G_zMvIk7ipG#+N~s$yF(4Lq$(0`8$jyyliZX<9&XTU*;yIZ%0O_0-&?vTj0{|st5;G2_pe0$Y}a=K3k&%zsd0{ zuN0bI;pVoT^nh43!Xov{(Ijc6+@U{5M?cpt$jdX%9cs!QYmm3Ku}ri(xu!9pp_&hk zvEliaNqhU9j!PJAlEvNH0zuz5Ysf~rhjj!|GgTp~ttB;HP1%D`)y{Bp zVRcz)a*#rvWrKDkl=lw@3>f#ImARdN>49p?rFyhdrIHFQX!UKI*hGOmW{Xy!ajtdb zlslenYGgta+?wv5bjkL&bWQ;!Qg5d?2>g=`TB9lG4J@1fW>VLd^M&dyFEFSull-H) zy4p<;F$ZK2ljSq9WC}JG(U9cJt}Z(uvMW_UYER;O8uzbpC#mQ1xS7?owW<2Fb+~=a zLt8`R9dEu9){APP(LIN+oE%nVz_JjNKt=S+@Ogr;s56LEzpf5-io-Fl$N0(j^%)uCD&6Y2nd4Xd zdoE9D&xjPnuPYjjn_e`a2|UlG_vxL1y*{_B)@-FCboABtdk0XrpYrYE5;0UFG?`g# z%ZtPc@RE7VXJ7GGgB0C5~vT(oXR_yzrA)_o?+-(rVqt*zZd5!$Qkzb z3MKfn1C?-)MQ+!wmmpO>>%aW*qkDkKab*5yS!AlyP|37Ff8)`wSn7@LMvvx!Ik;fd z;^T!!*4K=~6@GveSB*YF)md5X?d_&|#b+JHDQ8r-jhZs$N|`T5Msc|N#Ud_Ni*Hxr z7%X${GcsmkQF5zg^Q+mvU35Pvqs%A}qP&`j*N4vxX=`hxB6EM2rQ=3s1%rX4tZ3*Q zj+!<^zh1q{MfPt0GIs8^Sry*)seSiVNl4n#s1Le4c%`?NsYb_pZ{L8eSANQQv`M*R zi2De2yh$I9v10IVg7A%@!MNbJs_rQ@y zAY?4?5KaM~bTl)H*htH`BsIk!8rJ{2z=%?{YiU*bt%}N0V0z|}JdL^)+nvd6ZQ^Q9 zKQ1Y{GvOeQJ;uP`#EnGnN+qt=-KKgMsl5x5hvb7x6iusZN;zqfnlb_5M1e`R#l@Cj zmw=DrQ4C8xD%nl;yH`OX{TIZWCDEC=J!#aTEeo*)20|rwjYqM8KDQ%O@DzTSsLg2v z+eJH;+iXgB}vNFEG08VEOR(x4ZnD zNHWHx_2&%2Tqg0APM9D|N*&WRh0|Xo4hDqgJz-Zwz zEw>%T=(5e=u8nBr5c%lM_or43U0WpfQ>w`fw8Q1^M4`n3d$F?VdXsU&^ZQiC7_EQA zs$Y_rsDP&+{Vi|&=kKN_CNs<5Ra7oKE!rIE-z-~pMCt-t3lN$6o4ma~OX=4vqa<)* zPWDXG-x5Gi$@LZi<}8v(nY2J#$Ed-O4(9JYw^icMDc^;OIRRt#p5>2AEE$)Wah#1; zU@VS0*k)a5Z{J>U`ply(1ryd&alrrbVifA~D>+g#i!;(mK%b^Lo4%D=_{)J7Cu?d} z*=`^|PD8Tjz-yI;4VSZF#6@ai*=abgZN2%~S8IR2w#O2ymvukO@N2koK7s`li=1s$ zdKtocQL`ztH)MOJ^J7}sxu5wg5*&8&_fl&3Qn@jIlj4Zv6NX`LNiE zJ97o!EUNdv+J*IGkYCIavRY>*gevdDOmw#in{@29=PB+_&YoYXCoWEhBM3#!Di$^Z zkhz>3{EuXstT~)u9KpR%-O%8n(B?T>aD@8WIo-E){SqQH87qAb6!^s%KAr}7^G)dO zOj)sFg=JB_9)!Rh$rj)>vHW%)ik-RAZDa;c6ow{}Z7kb4*oPbHVu z-n-V4Wo2a*fhJ%5k{jgZPJPLzPK7$JM>=Lv8A}Boj)Bf!JN>&C6UBkuPqpDyP3_E8W2J{?QQAt~^(DRfGPxM*HA&pI0Z*wZ??3LbgBMwu?D}ZH`QC+r5ri z1Cqq33T1)S6#?(MVNZq=dbE{s{J1!7sW^5}UC2DWFiUcs`)u5b|Yiw>|-{v6mL3ctIn* zSHF~jd3rkkun{$`A3eN6sKOwf_B$*z}_mgT~ z`pwkTj7PG-k)f5xJzV>&xdd^MBWF zS|lEa>eEe9P&=h91*yGMp*>P5QxSMSr`Yu?32ygj2VEkER>}~T*T3)8m$6Z;MV5Vg z*&VnVJf>DtN$_->E7vNCrmRHF1rG|)X#qVbh%goY7oNh*Fpiu2a=!Kd2vW%dBz6X9 YS$*NJyJzxw`A? Date: Fri, 20 Mar 2026 12:42:58 +0100 Subject: [PATCH 44/62] Merge mbridge distillation for any_model (#1036) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Merge anymodel mbridge distillation ## Summary by CodeRabbit * **New Features** - Heterogeneous-layer support for Puzzletron AnyModel → Megatron-Core conversions - Knowledge-distillation workflow support and tooling - Export of trained models to HuggingFace format (new export helper) - Bridge support extended for Llama and Qwen3 architectures - New CLI options to control HuggingFace export destination and template model * **Tests** - Added integration test validating the distillation + export flow - Removed an obsolete trivial test * **Documentation** - Updated README to document HF export flags * **Chores** - CI: switched example-tests workflow to use NeMo and updated pip invocation --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/_example_tests_runner.yml | 7 +- .github/workflows/example_tests.yml | 10 +- .../puzzletron/mbridge_distillation/README.md | 2 +- .../mbridge_distillation/distill_hf.py | 3 + .../puzzletron/export/mbridge/__init__.py | 35 ++++ .../torch/puzzletron/export/mbridge/base.py | 142 +++++++++++++ .../export/mbridge/distillation_provider.py | 190 ++++++++++++++++++ .../export/mbridge/export_mbridge_to_hf.py | 89 ++++++++ .../torch/puzzletron/export/mbridge/llama.py | 38 ++++ .../torch/puzzletron/export/mbridge/qwen3.py | 38 ++++ .../mbridge_distillation/test_distill_hf.py | 160 +++++++++++++++ tests/examples/puzzletron/test_dummy.py | 18 -- 12 files changed, 705 insertions(+), 27 deletions(-) create mode 100644 modelopt/torch/puzzletron/export/mbridge/__init__.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/base.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/distillation_provider.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/llama.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/qwen3.py create mode 100644 tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py delete mode 100644 tests/examples/puzzletron/test_dummy.py diff --git a/.github/workflows/_example_tests_runner.yml b/.github/workflows/_example_tests_runner.yml index 32426b49a4..5aa0614c71 100644 --- a/.github/workflows/_example_tests_runner.yml +++ b/.github/workflows/_example_tests_runner.yml @@ -51,14 +51,15 @@ jobs: apt-get update && apt-get install -y git-lfs git lfs install --system - pip install ".${{ inputs.pip_install_extras }}" + # use `python -m pip` instead of `pip` to avoid conflicts with system pip for nemo containers + python -m pip install ".${{ inputs.pip_install_extras }}" if [[ "${{ inputs.example }}" == *"diffusers"* ]]; then echo "Uninstalling apex for diffusers: T5 Int8 (PixArt) + Apex is not supported as per https://github.com/huggingface/transformers/issues/21391" - pip uninstall -y apex || true + python -m pip uninstall -y apex || true fi - find examples/${{ inputs.example }} -name "requirements.txt" | while read req_file; do pip install -r "$req_file" || exit 1; done + find examples/${{ inputs.example }} -name "requirements.txt" | while read req_file; do python -m pip install -r "$req_file" || exit 1; done - name: Run tests run: | echo "Running tests for: ${{ inputs.example }}" diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index ab9c88346c..10e2c298cc 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -56,8 +56,8 @@ jobs: match_pattern: "^DCO$|^linux$" # Wait for DCO and Unit tests / linux to pass delay: 300s - ##### TensorRT-LLM Example Tests ##### - trtllm-pr: + ##### NeMo Example Tests ##### + nemo-pr: needs: [check-file-changes, wait-checks] if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' strategy: @@ -67,7 +67,7 @@ jobs: uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: - docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc5" + docker_image: "nvcr.io/nvidia/nemo:26.02" example: ${{ matrix.example }} pip_install_extras: "[hf,puzzletron,dev-test]" runner: linux-amd64-gpu-rtxpro6000-latest-2 @@ -76,13 +76,13 @@ jobs: example-pr-required-check: # Run even if example tests are skipped if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} - needs: [check-file-changes, trtllm-pr] + needs: [check-file-changes, nemo-pr] runs-on: ubuntu-latest steps: - name: Required GPU tests did not succeed if: | needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && ( - needs.trtllm-pr.result != 'success' + needs.nemo-pr.result != 'success' )) run: exit 1 diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md index d3420be096..f7dda866e8 100644 --- a/examples/puzzletron/mbridge_distillation/README.md +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -90,7 +90,7 @@ torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf. - Add `--trust_remote_code` if student or teacher checkpoints need HuggingFace custom modeling code. - The distilled Megatron-Bridge checkpoint will be saved to `--output_dir/checkpoints/iter_`. -- Add `--hf-export-path` to automatically export the final checkpoint to HuggingFace format after distillation. When using `--hf-export-path`, you must also provide `--hf-model` to specify the HuggingFace model ID to use as a template for export (e.g., `meta-llama/Llama-3.1-8B-Instruct`). The `--hf-model` should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation). +- Add `--hf-export-path` (or `--hf_export_path`) to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--hf-model` / `--hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation). - For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices. ## MMLU Evaluation Results diff --git a/examples/puzzletron/mbridge_distillation/distill_hf.py b/examples/puzzletron/mbridge_distillation/distill_hf.py index d21f35ec16..ac703909c2 100644 --- a/examples/puzzletron/mbridge_distillation/distill_hf.py +++ b/examples/puzzletron/mbridge_distillation/distill_hf.py @@ -144,6 +144,7 @@ def get_args(): parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)") # Export arguments parser.add_argument( + "--hf_export_path", "--hf-export-path", type=str, default=None, @@ -153,6 +154,7 @@ def get_args(): ), ) parser.add_argument( + "--hf_model", "--hf-model", type=str, required=True, @@ -307,6 +309,7 @@ def _build_model_provider(hf_path): train_iters=args.train_iters, hf_export_path=args.hf_export_path, hf_model=args.hf_model, + trust_remote_code=args.trust_remote_code, ) except Exception as e: print(f"⚠️ Export failed: {e}") diff --git a/modelopt/torch/puzzletron/export/mbridge/__init__.py b/modelopt/torch/puzzletron/export/mbridge/__init__.py new file mode 100644 index 0000000000..471e68984b --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Bridge adapters for Puzzletron AnyModel checkpoints. + +This module provides bridges for converting Puzzletron AnyModel checkpoints +(heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge. +""" + +# Import to register bridges (side effect) +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin +from modelopt.torch.puzzletron.export.mbridge.llama import ( # noqa: F401 + PuzzletronLlamaAnyModelBridge, +) +from modelopt.torch.puzzletron.export.mbridge.qwen3 import ( # noqa: F401 + PuzzletronQwen3AnyModelBridge, +) + +__all__ = [ + "HeterogeneousBridgeMixin", + "PuzzletronLlamaAnyModelBridge", + "PuzzletronQwen3AnyModelBridge", +] diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py new file mode 100644 index 0000000000..13ea6612af --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mixin class for bridges that support heterogeneous layer architectures. + +This module provides a mixin class for converting models with block_configs +(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge. +""" + +import dataclasses +import json +from collections.abc import Callable +from dataclasses import dataclass, fields + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.transformer.spec_utils import ModuleSpec + + +def heterogeneous_layer_spec(config) -> ModuleSpec: + """Get GPT heterogeneous layer spec using Transformer Engine.""" + return get_gpt_heterogeneous_layer_spec(config, use_te=True) + + +@dataclass +class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig): + """Generic provider for AnyModel checkpoints with block_configs.""" + + # Heterogeneous configuration fields + heterogeneous_layers_config_path: str | None = None + heterogeneous_layers_config_encoded_json: str = "" + transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec + + def __getattr__(self, name: str): + """Handle missing attributes for OmegaConf compatibility. + + Returns empty list for per_block_parameters if not yet initialized (before finalize()). + This allows OmegaConf to serialize/deserialize configs without errors. Actual usage + should call finalize() first to set per_block_parameters as a real attribute. + """ + if name == "per_block_parameters": + # Return existing attribute if set, otherwise [] for OmegaConf compatibility + try: + return object.__getattribute__(self, name) + except AttributeError: + return [] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + +class HeterogeneousBridgeMixin: + """Mixin for bridges supporting heterogeneous layer architectures (block_configs). + + Must be used with multiple inheritance alongside a model-specific bridge. + Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge) + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: + """Convert HF AnyModel config to Megatron GPTModelProvider. + + This method: + 1. Calls the parent bridge's provider_bridge() to get a GPTModelProvider with all + model-specific settings (e.g., LlamaBridge sets normalization="RMSNorm", etc.) + 2. Converts the provider to a dict and filters to only fields accepted by + GenericHeterogeneousProvider (which inherits from GPTModelProvider, so all valid + GPTModelProvider fields are preserved) + 3. Adds heterogeneous configuration and returns GenericHeterogeneousProvider + + All parameters from the parent bridge (e.g., LlamaBridge) are maintained because + GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all + the fields that the parent bridge sets. + """ + + parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] + + provider_kwargs = dataclasses.asdict(parent_provider) + + # Filter to only fields that GenericHeterogeneousProvider accepts. + # GenericHeterogeneousProvider inherits from GPTModelProvider, so it includes all + # GPTModelProvider fields. Model-specific fields from subclasses (e.g., MistralModelProvider, + # GPTOSSModelProvider) are filtered out because GenericHeterogeneousProvider only inherits + # from GPTModelProvider, not from model-specific subclasses. + # + # Note: This logic may not work for bridges like MistralBridge or GPTOSSBridge if they + # use model-specific parameters not supported by GenericHeterogeneousProvider (e.g., + # scale_factor, yarn_rotary_scaling_factor, moe_* parameters). In such cases, create a + # model-specific heterogeneous provider that inherits from the model-specific provider. + valid_fields = {f.name for f in fields(GenericHeterogeneousProvider)} + + # Only keep kwargs that are valid fields + provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields} + + provider_kwargs["heterogeneous_layers_config_encoded_json"] = ( + self._build_heterogeneous_config_json(hf_pretrained.config) + ) + return GenericHeterogeneousProvider(**provider_kwargs) + + def _build_heterogeneous_config_json(self, hf_config) -> str: + """Build heterogeneous layers config JSON from HF config.""" + + hf_config_dict = json.loads(hf_config.to_json_string()) + + mcore_block_configs = [ + self._convert_block_config(block) for block in hf_config_dict["block_configs"] + ] + return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False) + + def _convert_block_config(self, block: dict) -> dict: + """Convert a single block config from HF format to MCore format.""" + return { + "attention": self._convert_attention_config(block["attention"]), + "ffn": self._convert_ffn_config(block["ffn"]), + } + + def _convert_attention_config(self, attention_config: dict) -> dict: + """Convert attention config from HF format to MCore format.""" + attention_config = attention_config.copy() + attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads") + return attention_config + + def _convert_ffn_config(self, ffn_config: dict) -> dict: + """Convert FFN/MLP config from HF format to MCore format.""" + ffn_config = ffn_config.copy() + ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size") + return ffn_config diff --git a/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py new file mode 100644 index 0000000000..fa49dc29c5 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Upstream this fix to Megatron-Bridge and remove this local copy. + +import logging +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, Optional + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider +from megatron.bridge.models.transformer_config import TransformerConfig +from megatron.core.models.gpt import GPTModel as MCoreGPTModel + +import modelopt.torch.distill as mtd +import modelopt.torch.distill.plugins.megatron as mtd_mcore + +if TYPE_CHECKING: + from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig + + +logger = logging.getLogger(__name__) + + +@dataclass +class DistillationProvider(TransformerConfig): + """Provider for Megatron Core GPT models in distillation mode. + + Please use `convert_to_distillation_provider()` to create an instance of this class. + """ + + teacher: Optional[GPTModelProvider | MambaModelProvider] = None + kd_config: Optional["ModelOptDistillConfig"] = None + + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "Use `convert_to_distillation_provider()` to create an instance of this class." + ) + + def __post_init__(self): + assert getattr(self, "teacher", None) is not None, "Teacher model must be provided." + + shared_attrs = [ + "tensor_model_parallel_size", + "pipeline_model_parallel_size", + "context_parallel_size", + "seq_length", + "pipeline_dtype", + ] + for attr in shared_attrs: + if getattr(self, attr) != getattr(self.teacher, attr): + raise ValueError(f"Student and teacher providers must have the same {attr}.") + + # Logits are overwritten in-place when TE cross-entropy loss is enabled, so switch it back to native version. + self.cross_entropy_fusion_impl = "native" + + # Hack to dynamically subclass other providers and still use their methods + self._super_class = self.__class__.__bases__[0] + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Configure and instantiate a ModelOpt DistillationModel based on this configuration. + + Args: + pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage + post_process: Whether to include post-processing in the model, defaults to last pipeline stage + vp_stage: Virtual pipeline stage + + Returns: + MCoreGPTModel: Configured ModelOpt DistillationModel instance + """ + if vp_stage is not None: + raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.") + + assert self.teacher is not None, "Teacher model must be provided." + student_model = self._super_class.provide(self, pre_process, post_process, vp_stage) # type: ignore[attr-defined] + + # Finalize teacher provider before creating model (required for heterogeneous models). + # + # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in + # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It's created during + # provider creation (bridge.to_megatron_provider()), but finalize() ensures they're consistent + # with current parallelism settings and distributed context. Student model creation (above) + # initializes parallel_state (process groups, TP/PP config), which weight loading/scatter + # requires. During teacher model creation, get_config_for_layer() is called (transformer_block.py:341) + # for each layer, which uses per_block_parameters and current tensor_model_parallel_size to + # determine layer architecture. Without finalize() in this context, architecture expectations + # don't match checkpoint weights, causing: + # ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0 + # (expected (2880, 4096), got (3584, 4096)) + # + # Note: This explanation needs to be confirmed yet. + self.teacher.finalize() + + # Hack to get teacher's pre-wrap hooks called to potentially load HF weights + teacher_model = self.teacher.provide_distributed_model( + wrap_with_ddp=False, mixed_precision_wrapper=None + )[0] + + kd_cfg = mtd_mcore.setup_distillation_config( + self.kd_config, student_model.config, teacher_model.config + ) + modelopt_cfg = { + "teacher_model": teacher_model, + "criterion": kd_cfg.criterion, + "loss_balancer": kd_cfg.loss_balancer, + } + kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)]) + mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg) + + return kd_model + + def to_cfg_dict(self) -> dict[str, Any]: + """Custom method to save equivalent to the original provider class. + + Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML. + There is no need to restore a `DistillationProvider` from the run config file, as + it can always be re-converted using the original student provider. + + Returns: + Dictionary representation of this provider class + """ + from megatron.bridge.training.utils.config_utils import _ConfigContainerBase + + result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"} + + # Include all fields from the original provider class (self._super_class), not just DistillationProvider + # This ensures fields like heterogeneous_layers_config_encoded_json are preserved + excluded_fields = {"teacher", "kd_config"} + for field in fields(self._super_class): + if field.name.startswith("_") or field.name in excluded_fields: + continue + # Only include if the field exists on this instance (it should, since we converted from the original provider) + if hasattr(self, field.name): + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + + # Also include any additional fields from DistillationProvider itself (if any) + for field in fields(self): + if field.name.startswith("_") or field.name in excluded_fields: + continue + # Skip if already included from _super_class + if field.name not in result: + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + + return result + + def __setattr__(self, name, value): + super().__setattr__(name, value) + # Mirror to teacher if it has that attribute + if hasattr(self.teacher, name): + setattr(self.teacher, name, value) + + +def convert_to_distillation_provider( + student_provider: GPTModelProvider | MambaModelProvider, + teacher_provider: GPTModelProvider | MambaModelProvider, + kd_config: Optional["ModelOptDistillConfig"] = None, +) -> "DistillationProvider": + """Convert a given model provider to a DistillationProvider.""" + + assert isinstance(student_provider, (GPTModelProvider, MambaModelProvider)), ( + "Student provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + assert isinstance(teacher_provider, (GPTModelProvider, MambaModelProvider)), ( + "Teacher provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + + DistillationProvider.__bases__ = (type(student_provider),) + student_provider.__class__ = DistillationProvider + + student_provider.teacher = teacher_provider + student_provider.kd_config = kd_config + student_provider.__post_init__() + + return student_provider diff --git a/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py new file mode 100644 index 0000000000..59e1d7dade --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Export utilities for Megatron-Bridge checkpoints.""" + +import shutil +from pathlib import Path + +from megatron.bridge import AutoBridge + +from modelopt.torch.utils import print_rank_0 + + +def export_to_hf_and_copy_config( + student_hf_path: str, + checkpoint_dir: str, + train_iters: int, + hf_export_path: str, + hf_model: str, + trust_remote_code: bool = False, +) -> None: + """ + Export Megatron checkpoint to HuggingFace format and copy config.json from student model. + + TODO: This script should not be needed (manually copying config.json from + student model to exported model). Remove it once export_to_hf() in AutoBridge + supports copying/preserving config.json from student model. + + + Args: + student_hf_path: Path to the original student HuggingFace model (source of config.json) + checkpoint_dir: Base directory where Megatron checkpoints are stored + train_iters: Number of training iterations (used to construct final checkpoint path) + hf_export_path: Directory path where the HuggingFace model will be saved + hf_model: HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct) + trust_remote_code: Whether to trust remote modeling code when loading the HF template model + """ + print_rank_0(f"\n{'=' * 80}") + print_rank_0("Exporting to HuggingFace format...") + print_rank_0(f"{'=' * 80}\n") + + # Construct path to final checkpoint iteration (format: iter_0000100 for 100 iterations) + final_iter_dir = Path(checkpoint_dir) / f"iter_{train_iters:07d}" + print_rank_0(f"📂 Using final checkpoint: {final_iter_dir}") + + # Use the final iteration directory for export (export_ckpt will validate it exists) + megatron_path = str(final_iter_dir) + + # Create bridge using standard model ID (not AnyModel checkpoint) to avoid sharding structure issues + print_rank_0("🌉 Creating bridge...") + print_rank_0(f" Using model ID: {hf_model}") + bridge = AutoBridge.from_hf_pretrained(hf_model, trust_remote_code=trust_remote_code) + + print_rank_0("📤 Exporting to HuggingFace format...") + # Use strict=False for test_distill_hf.py which uses small models (2 layers) with fewer layers + # than the template model (32 layers). This allows partial exports when some tensors are missing. + # Note: This is NOT needed when running on real compressed puzzletron student models, + # which have the same number of layers as the template model (some may be skipped via no_op + # in block_configs, but all layer tensors are still present in the checkpoint). + bridge.export_ckpt( + megatron_path=megatron_path, + hf_path=hf_export_path, + show_progress=True, + strict=False, # Needed for test_distill_hf.py small models; not needed for real compressed models + ) + + print_rank_0(f"✅ Successfully exported model to: {hf_export_path}") + + # Copy config.json from student model to exported model (preserves block_configs) + student_config_path = Path(student_hf_path) / "config.json" + exported_config_path = Path(hf_export_path) / "config.json" + + print_rank_0(f"📋 Copying config.json from student model: {student_config_path}") + shutil.copy(student_config_path, exported_config_path) + print_rank_0(f"✅ Copied config.json to: {exported_config_path}") + + print_rank_0(f"\n{'=' * 80}") + print_rank_0("Export complete!") diff --git a/modelopt/torch/puzzletron/export/mbridge/llama.py b/modelopt/torch/puzzletron/export/mbridge/llama.py new file mode 100644 index 0000000000..b802215298 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/llama.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Llama-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.llama.llama_bridge import LlamaBridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import LlamaForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) +class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge): + """ + Megatron Bridge for Puzzletron Llama-based AnyModel checkpoints. + + Extends LlamaBridge with support for heterogeneous layer architectures (block_configs). + All Llama-specific settings are inherited from LlamaBridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses LlamaBridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from LlamaBridge diff --git a/modelopt/torch/puzzletron/export/mbridge/qwen3.py b/modelopt/torch/puzzletron/export/mbridge/qwen3.py new file mode 100644 index 0000000000..ace20fbf89 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/qwen3.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Qwen3-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.qwen.qwen3_bridge import Qwen3Bridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import Qwen3ForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel) +class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge): + """ + Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints. + + Extends Qwen3Bridge with support for heterogeneous layer architectures (block_configs). + All Qwen3-specific settings are inherited from Qwen3Bridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses Qwen3Bridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from Qwen3Bridge diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py new file mode 100644 index 0000000000..7b0c9a32f6 --- /dev/null +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for distill_hf.py script.""" + +from pathlib import Path + +import torch +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +from _test_utils.torch.distributed.utils import get_free_port +from _test_utils.torch.puzzletron.utils import create_and_save_small_hf_model, create_tokenizer +from transformers import AutoModelForCausalLM + +from modelopt.torch.puzzletron.anymodel import convert_model + + +def test_distill_hf(project_root_path: Path, tmp_path: Path): + """Integration test for distill_hf.py. + + Creates Llama models programmatically, converts them to heterogeneous format (AnyModel), + and runs mbridge distillation. The models are created with reduced size for faster testing. + Models are converted to include block_configs. + """ + # Prepare student and teacher models + student_hf_path, teacher_hf_path = _prepare_student_and_teacher_models( + project_root_path, tmp_path + ) + + output_dir = tmp_path / "distill_output" + hf_export_dir = tmp_path / "hf_export" + + # Build command-line arguments for distill_hf.py + nproc_per_node = torch.cuda.device_count() + tp_size = nproc_per_node + train_iters = 5 + + cmd_parts = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--master-addr", + "127.0.0.1", + "--master-port", + str(get_free_port()), + "distill_hf.py", + "--use_mock_data", + ] + extend_cmd_parts( + cmd_parts, + student_hf_path=student_hf_path, + teacher_hf_path=teacher_hf_path, + output_dir=str(output_dir), + tp_size=tp_size, + pp_size=1, + seq_length=128, + split="99,1,0", + mbs=1, + gbs=4, + train_iters=train_iters, + lr=0.0001, + min_lr=1e-5, + lr_warmup_iters=2, + eval_interval=100, + eval_iters=0, + log_interval=5, + hf_export_path=str(hf_export_dir), + hf_model="meta-llama/Llama-3.1-8B-Instruct", + ) + + run_example_command(cmd_parts, example_path="puzzletron/mbridge_distillation") + + # Check that distillation checkpoint contains run_config.yaml + run_config_path = output_dir / "checkpoints" / f"iter_{train_iters:07d}" / "run_config.yaml" + assert run_config_path.exists(), f"Expected run_config.yaml to exist at: {run_config_path}" + + # Verify that the distilled model can be loaded in HuggingFace format + model = AutoModelForCausalLM.from_pretrained( + str(hf_export_dir), + local_files_only=True, + trust_remote_code=True, + ) + assert model is not None, "Failed to load distilled model with AutoModelForCausalLM" + + print( + f"PYTEST SUMMARY: test_distill_hf test has finished successfully. " + f"Output directory: {output_dir}, HF export: {hf_export_dir}" + ) + + +def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) -> tuple[str, str]: + """Prepare student and teacher models for distillation. + + Creates Llama models programmatically, converts them to heterogeneous format (AnyModel), + and returns the paths to the converted checkpoints. + + Args: + project_root_path: Path to the project root directory + tmp_path: Temporary directory for test artifacts + + Returns: + Tuple of (student_hf_path, teacher_hf_path) as strings + """ + + # Create temporary directories for models + student_hf_dir = tmp_path / "student_hf" + teacher_hf_dir = tmp_path / "teacher_hf" + + # Create tokenizer (uses local tokenizer from test resources) + tokenizer = create_tokenizer(project_root_path) + + # Create student model using utility function (loads config from Hub). + # TODO: Make the student model using different ffn sizes across layers. + create_and_save_small_hf_model( + output_path=str(student_hf_dir), + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + hf_model_name="meta-llama/Llama-3.1-8B-Instruct", + hybrid_override_pattern=None, + ) + + # Create teacher model (same as student for testing) + create_and_save_small_hf_model( + output_path=str(teacher_hf_dir), + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + hf_model_name="meta-llama/Llama-3.1-8B-Instruct", + hybrid_override_pattern=None, + ) + + # Convert models to AnyModel format BEFORE distillation + # This is needed as converted checkpoints will be used as input for distillation later + student_anymodel_dir = tmp_path / "student_anymodel" + teacher_anymodel_dir = tmp_path / "teacher_anymodel" + + convert_model( + input_dir=str(student_hf_dir), + output_dir=str(student_anymodel_dir), + converter="llama", + ) + + convert_model( + input_dir=str(teacher_hf_dir), + output_dir=str(teacher_anymodel_dir), + converter="llama", + ) + print("Models converted to AnyModel format:") + print(f" Student AnyModel: {student_anymodel_dir}") + print(f" Teacher AnyModel: {teacher_anymodel_dir}") + + return student_anymodel_dir, teacher_anymodel_dir diff --git a/tests/examples/puzzletron/test_dummy.py b/tests/examples/puzzletron/test_dummy.py deleted file mode 100644 index d07694471a..0000000000 --- a/tests/examples/puzzletron/test_dummy.py +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def test_dummy(): - assert True From 2b6572c00641b1e995ab07751ca3078440df0577 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 20 Mar 2026 15:23:38 +0100 Subject: [PATCH 45/62] =?UTF-8?q?MR=20branch=20for=20the=20remaining=20dif?= =?UTF-8?q?ference=20between=20dkorzekwa/any=5Fmodel=20an=E2=80=A6=20(#104?= =?UTF-8?q?7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? The remaining difference between dkorzekwa/any_model and dkorzekwa/anymodel_mbridgedist ## Summary by CodeRabbit * **New Features** * Added an end-to-end MIP sweep workflow that runs sweeps, collects metrics, and writes CSV reports. * **Improvements** * Simplified attention implementation selection and clarified attention display wording. * Realization pipeline now returns realized model solution paths. * Caching behavior adjusted for token ID retrieval. * Test model generation scales layer count with distributed size. * **Bug Fixes** * Checkpoint hook state restored using the correct load API. * **Chores** * Extended pre-commit license hook exclusion. * **Tests** * Added GPU/distributed test fixtures and dtype/checkout helpers. --------- Signed-off-by: Daniel Korzekwa Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../deci_lm_hf_code/configuration_decilm.py | 10 +- .../puzzletron/mip/mip_and_realize_models.py | 4 +- modelopt/torch/puzzletron/mip/sweep.py | 297 ++++++++++++++++++ .../puzzletron/utils/checkpoint_manager.py | 2 +- .../torch/puzzletron/utils/data/dataset.py | 2 +- modelopt/torch/puzzletron/utils/parsing.py | 6 +- tests/_test_utils/torch/puzzletron/utils.py | 4 +- tests/gpu/torch/conftest.py | 59 ++++ 8 files changed, 369 insertions(+), 15 deletions(-) create mode 100644 modelopt/torch/puzzletron/mip/sweep.py create mode 100644 tests/gpu/torch/conftest.py diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py index c37b9adaf7..6ff0e26a4e 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py @@ -20,7 +20,7 @@ import warnings from typing import Any -from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available +from transformers.utils import is_flash_attn_2_available # , is_torch_sdpa_available from .block_config import BlockConfig from .transformers_4_44_2__configuration_llama import LlamaConfig @@ -119,12 +119,8 @@ def _delete_per_layer_attributes(self): def _choose_llama4_attn_implementation(self, llama4_attn_implementation): self.llama4_attn_implementation = llama4_attn_implementation if self.llama4_attn_implementation is None: - if is_torch_sdpa_available(): - _print_once("auto-setting llama4_attn_implementation to sdpa") - self.llama4_attn_implementation = "sdpa" - else: - _print_once("auto-setting llama4_attn_implementation to eager") - self.llama4_attn_implementation = "eager" + _print_once("auto-setting llama4_attn_implementation to sdpa") + self.llama4_attn_implementation = "sdpa" def _choose_llama3_attn_implementation(self, kwargs: dict[str, Any]) -> str: attn_implementation = kwargs.pop("attn_implementation", None) diff --git a/modelopt/torch/puzzletron/mip/mip_and_realize_models.py b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py index e241021ec9..17d8e4a2db 100644 --- a/modelopt/torch/puzzletron/mip/mip_and_realize_models.py +++ b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py @@ -38,7 +38,7 @@ def launch_realize_model(cfg: DictConfig): validate_puzzle_solutions(args=cfg.realize_model) -def launch_mip_and_realize_model(cfg: DictConfig): +def launch_mip_and_realize_model(cfg: DictConfig) -> list[str]: # Determine device for distributed operations (NCCL requires CUDA tensors) device = "cpu" if dist.size() > 1: @@ -69,3 +69,5 @@ def launch_mip_and_realize_model(cfg: DictConfig): cfg.realize_model.solutions_path = Path(solution_path) launch_realize_model(cfg) dist.barrier() + + return solution_paths diff --git a/modelopt/torch/puzzletron/mip/sweep.py b/modelopt/torch/puzzletron/mip/sweep.py new file mode 100644 index 0000000000..82046934bc --- /dev/null +++ b/modelopt/torch/puzzletron/mip/sweep.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MIP sweep functionality for exploring multiple memory compression rates.""" + +import json +from pathlib import Path + +import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.tools.logger import mprint + + +def get_teacher_memory_from_subblock_stats(hydra_cfg) -> float: + """Calculate teacher model memory from subblock_stats.json. + + Replicates the MIP solver's memory calculation logic: + - Loads subblock_stats.json which contains memory measurements for all subblock configs + - Finds the teacher FFN subblock (with full intermediate_size) + - Finds the teacher Attention subblock (full attention, not no_op) + - Calculates: non_block_memory + (ffn_memory + attention_memory) * num_layers + + This matches how the MIP solver computes total model memory via _get_block_stats(). + + Args: + hydra_cfg: Hydra configuration object + + Returns: + Total teacher memory in MiB + """ + puzzle_dir = Path(hydra_cfg.puzzle_dir) + + # Read config.json directly from the teacher model path + teacher_dir = Path(hydra_cfg.teacher_dir) + config_file = teacher_dir / "config.json" + + with open(config_file) as f: + config_dict = json.load(f) + + num_layers = config_dict["num_hidden_layers"] + teacher_ffn_intermediate = config_dict["intermediate_size"] + teacher_num_kv_heads = config_dict["num_key_value_heads"] + + # Get the MIP configuration + mip_subblock_args = hydra_cfg.mip.subblock_stats_args[0] + batch_size = mip_subblock_args["batch_size"] + weights_dtype = str(mip_subblock_args["weights_dtype"]) + activations_dtype = str(mip_subblock_args["activations_dtype"]) + kv_cache_dtype = str(mip_subblock_args["kv_cache_dtype"]) + + # Load subblock_stats.json + subblock_stats_path = puzzle_dir / "subblock_stats.json" + if not subblock_stats_path.exists(): + raise FileNotFoundError( + f"subblock_stats.json not found at {subblock_stats_path}. " + "Please run the full pipeline first without --mip-only flag." + ) + + with open(subblock_stats_path) as f: + subblock_stats_list = json.load(f) + + # Find the entry matching our MIP configuration and teacher's n_embd + matching_stats = None + for stats_entry in subblock_stats_list: + args = stats_entry["args"] + if ( + args["batch_size"] == batch_size + and args["weights_dtype"] == weights_dtype + and args["activations_dtype"] == activations_dtype + and args["kv_cache_dtype"] == kv_cache_dtype + and args.get("n_embd") == config_dict["hidden_size"] + ): + matching_stats = stats_entry + break + + if matching_stats is None: + raise ValueError( + f"No subblock_stats entry found for batch_size={batch_size}, " + f"dtypes=({weights_dtype}, {activations_dtype}, {kv_cache_dtype}), " + f"n_embd={config_dict['hidden_size']}" + ) + + # Get non-block memory (embeddings, LM head, etc.) + total_memory = matching_stats.get("non_block", {}).get("memory_mib", 0.0) + + # Find the teacher FFN and Attention subblocks + # Note: Each subblock is EITHER attention OR ffn, not both + # We need to find BOTH and add their memory together + teacher_ffn_subblock = None + teacher_attention_subblock = None + + for subblock in matching_stats.get("subblocks", []): + subblock_class = subblock.get("subblock_config_class", "") + subblock_config = subblock.get("subblock_config", {}) + + # Check for FFN subblocks with teacher's intermediate_size + if "FFN" in subblock_class: + ffn_size = subblock_config.get("intermediate_size") + if ffn_size == teacher_ffn_intermediate and not subblock_config.get("no_op", False): + teacher_ffn_subblock = subblock + + # Check for Attention subblocks with teacher's num_key_value_heads + elif "Attention" in subblock_class: + kv_heads = subblock_config.get("num_key_value_heads") + if kv_heads == teacher_num_kv_heads and not subblock_config.get("no_op", False): + teacher_attention_subblock = subblock + + if teacher_ffn_subblock is None: + raise ValueError( + f"Could not find teacher FFN subblock with intermediate_size={teacher_ffn_intermediate}" + ) + + if teacher_attention_subblock is None: + raise ValueError( + f"Could not find teacher Attention subblock with num_key_value_heads={teacher_num_kv_heads}" + ) + + # Calculate total teacher memory: non_block + (ffn_memory + attention_memory) * num_layers + per_layer_memory = teacher_ffn_subblock["memory_mib"] + teacher_attention_subblock["memory_mib"] + total_memory += per_layer_memory * num_layers + + return total_memory + + +def extract_solution_results( + solution_path: Path, + target_memory_mib: float, + teacher_memory_mib: float, + compression_rate: float, +) -> dict: + """Extract results from a completed MIP solution. + + Args: + solution_path: Path to the solutions.json file (not the directory!) + target_memory_mib: Target memory constraint used for MIP + teacher_memory_mib: Teacher model memory in MiB + compression_rate: Compression rate applied + + Returns: + Dictionary containing extracted metrics + """ + result = { + "compression_rate": compression_rate, + "target_memory_mib": target_memory_mib, + "teacher_memory_mib": teacher_memory_mib, + } + + # solution_path is the path to solutions.json file, get parent directory + solution_dir = solution_path.parent + + # Load solutions.json for actual memory and parameters + solutions_file = solution_dir / "solutions.json" + with open(solutions_file) as f: + solutions_data = json.load(f) + solution = solutions_data[0] # First solution + total_costs = solution.get("total_costs", {}) + result["actual_memory_mib"] = total_costs.get("stats.memory_mib", None) + result["num_params"] = total_costs.get("stats.num_params", None) + + # Load solution_0.json for accuracy metrics + validation_dir = solution_dir / "solutions--validation" + # TODO: There could be multiple solutions, but we only need the first one. Is it the best solution? + solution_0_file = validation_dir / "solution_0.json" + + with open(solution_0_file) as f: + validation_data = json.load(f) + result["lm_loss"] = validation_data.get("lm_loss", {}).get("avg", None) + result["token_accuracy_top_1"] = validation_data.get("token_accuracy_top_1", {}).get( + "avg", None + ) + result["token_accuracy_top_5"] = validation_data.get("token_accuracy_top_5", {}).get( + "avg", None + ) + result["token_accuracy_top_10"] = validation_data.get("token_accuracy_top_10", {}).get( + "avg", None + ) + + return result + + +def write_results_to_csv(results: list, output_csv: str): + """Write sweep results to CSV file. + + Args: + results: List of result dictionaries + output_csv: Path to output CSV file + """ + import csv + + # Define CSV columns in desired order + columns = [ + "compression_rate", + "target_memory_mib", + "actual_memory_mib", + "teacher_memory_mib", + "num_params", + "lm_loss", + "token_accuracy_top_1", + "token_accuracy_top_5", + "token_accuracy_top_10", + ] + + # Write CSV + output_path = Path(output_csv) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=columns) + writer.writeheader() + writer.writerows(results) + + mprint(f"Results written to: {output_path}") + + +def run_mip_sweep(hydra_cfg): + """Run MIP for multiple memory compression rates and generate CSV with results. + + This function is called when mip.sweep.enabled is True in the config. + + Args: + hydra_cfg: Hydra configuration object with mip.sweep settings + """ + mprint("=" * 80) + mprint("MIP Sweep Mode Enabled") + mprint("=" * 80) + + # Get sweep configuration + sweep_cfg = hydra_cfg.mip.sweep + compression_rates = sweep_cfg.memory_compression_rates + output_csv = sweep_cfg.output_csv + puzzle_dir = Path(hydra_cfg.puzzle_dir) + + mprint(f"Compression rates: {compression_rates}") + mprint(f"Output CSV: {output_csv}") + mprint(f"Puzzle directory: {puzzle_dir}") + + # Calculate teacher memory from subblock_stats + teacher_memory = get_teacher_memory_from_subblock_stats(hydra_cfg) + mprint( + f"Teacher memory (from subblock_stats): {teacher_memory:.1f} MiB ({teacher_memory / 1024:.1f} GiB)" + ) + + # Collect results + all_results = [] + + # Run MIP for each compression rate + for compression_rate in compression_rates: + target_memory_mib = teacher_memory * compression_rate + mprint("\n" + "=" * 80) + mprint( + f"Running MIP for compression_rate={compression_rate:.2f} " + f"(target={target_memory_mib:.1f} MiB = {target_memory_mib / 1024:.1f} GiB)" + ) + mprint("=" * 80) + + # Modify config dynamically + hydra_cfg.mip.human_constraints.target_memory = target_memory_mib + + # Run MIP and realize models (reuse existing distributed logic!) + solution_paths = mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + + # Extract results (only on master rank) + if dist.is_master(): + for solution_path in solution_paths: + result = extract_solution_results( + solution_path=Path(solution_path), + target_memory_mib=target_memory_mib, + teacher_memory_mib=teacher_memory, + compression_rate=compression_rate, + ) + all_results.append(result) + + mprint( + f"✓ Results: actual_memory={result['actual_memory_mib']:.1f} MiB, " + f"lm_loss={result['lm_loss']:.4f}" + ) + + # Write results to CSV (only on master rank) + if dist.is_master(): + mprint("\n" + "=" * 80) + mprint("MIP Sweep Complete - Writing Results") + mprint("=" * 80) + write_results_to_csv(all_results, output_csv) + mprint(f"Completed {len(all_results)} sweep runs") + mprint("=" * 80) diff --git a/modelopt/torch/puzzletron/utils/checkpoint_manager.py b/modelopt/torch/puzzletron/utils/checkpoint_manager.py index 3fc4bf87e2..90303e2de9 100644 --- a/modelopt/torch/puzzletron/utils/checkpoint_manager.py +++ b/modelopt/torch/puzzletron/utils/checkpoint_manager.py @@ -135,7 +135,7 @@ def load_hook_states(self, activation_hooks) -> bool: loaded_count = 0 for module_name, hook in activation_hooks.items(): if module_name in hook_states: - hook.load_state(hook_states[module_name]) + hook.load_state_dict(hook_states[module_name]) loaded_count += 1 # Log progress info if available (only for a few hooks to avoid spam) diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py index a71049105e..fffc2a3a1d 100644 --- a/modelopt/torch/puzzletron/utils/data/dataset.py +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -287,7 +287,7 @@ def permute( # this is expensive so we cache it -@functools.cache +@functools.lru_cache(maxsize=None) def get_fim_token_ids(tokenizer): # ugly fix for Salesforce/codegen25-7b-multi tokenizer if hasattr(tokenizer, "encoder"): diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py index 97f698ba91..ff5bb6963a 100644 --- a/modelopt/torch/puzzletron/utils/parsing.py +++ b/modelopt/torch/puzzletron/utils/parsing.py @@ -150,9 +150,9 @@ def _format_attention_config(attention_config) -> str: if attention_config.no_op: return "❌ no_op" - n_heads = attention_config.n_heads_in_group - if n_heads is not None: - return f"{n_heads} heads in group" + num_kv_heads = attention_config.num_key_value_heads + if num_kv_heads is not None: + return f"{num_kv_heads} kv heads" if attention_config.replace_with_linear: return "linear replacement" diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index 8b7711c3cb..b5e32566de 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -129,14 +129,14 @@ def create_and_save_small_hf_model( config.vocab_size = vocab_size config.hidden_size = 256 config.intermediate_size = 512 - config.num_hidden_layers = 2 + config.num_hidden_layers = max(2, dist.size()) config.num_attention_heads = 32 config.num_key_value_heads = 8 config.max_position_embeddings = 512 # Fix layer_types to match num_hidden_layers (newer transformers validates this) if hasattr(config, "layer_types") and config.layer_types is not None: - config.layer_types = config.layer_types[:2] + config.layer_types = config.layer_types[: config.num_hidden_layers] # Fix rope_scaling to be consistent with max_position_embeddings if hasattr(config, "rope_scaling") and config.rope_scaling is not None: diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py new file mode 100644 index 0000000000..a38322d141 --- /dev/null +++ b/tests/gpu/torch/conftest.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +import torch.distributed as dist +from _test_utils.torch.distributed.utils import init_process + +import modelopt.torch.opt as mto + + +@pytest.fixture +def distributed_setup_size_1(): + init_process(rank=0, size=1, backend="nccl") + yield + dist.destroy_process_group() + + +@pytest.fixture +def need_2_gpus(): + if torch.cuda.device_count() < 2: + pytest.skip("Need at least 2 GPUs to run this test") + + +@pytest.fixture +def need_8_gpus(): + if torch.cuda.device_count() < 8: + pytest.skip("Need at least 8 GPUs to run this test") + + +@pytest.fixture +def need_4_gpus(): + if torch.cuda.device_count() < 4: + pytest.skip("Need at least 4 GPUs to run this test") + + +@pytest.fixture(scope="module") +def set_torch_dtype(request): + orig_dtype = torch.get_default_dtype() + torch.set_default_dtype(request.param) + yield + torch.set_default_dtype(orig_dtype) + + +@pytest.fixture(scope="session", autouse=True) +def enable_hf_checkpointing(): + mto.enable_huggingface_checkpointing() From 110316a09fce748963cb05f1b3755a1ab7db0219 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Mar 2026 14:58:54 +0100 Subject: [PATCH 46/62] Dkorzekwa/decilm hf code cleanup (#1071) ### What does this PR do? - Delete unused decilm code ## Summary by CodeRabbit ## Release Notes * **Removals** * Removed model conversion utilities for Llama-Nemotron format * Removed DeciLM model classes, tokenizer implementations, and configuration utilities * Removed checkpoint import/export functionality * Removed heterogeneous transformer layer specifications and configuration builders * **Updates** * Updated pre-commit configuration for additional file exclusions * Updated imports across modules to reflect removed dependencies --------- Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 2 + .../nemo_export/convert_hf_to_nemo.py | 98 -- .../nemo_export/convert_nemo_to_hf.py | 96 -- .../puzzletron/decilm/conversion_utils.py | 157 --- .../converters/convert_llama3_to_decilm.py | 153 --- .../megatron_lm__megatron_tokenizer.py | 148 --- .../deci_lm_hf_code/megatron_lm__tokenizer.py | 187 --- .../decilm/deci_lm_hf_code/modeling_decilm.py | 888 +------------- .../deci_lm_hf_code/tokenization_decilm.py | 195 ---- .../puzzletron/export/MCore/llama_nemotron.py | 1015 ----------------- .../export/MCore/llama_nemotron_utils.py | 729 ------------ .../MCore/puzzletron_hf_config_utils.py | 142 --- .../export/MCore/puzzletron_layer_specs.py | 928 --------------- .../replacement_library.py | 2 - .../init_child_from_parent.py | 8 +- .../puzzletron/tools/checkpoint_utils_hf.py | 135 --- .../puzzletron/tools/post_init_sparse.py | 4 +- .../tools/sharded_checkpoint_utils.py | 72 -- ...validate_puzzle_with_multi_replacements.py | 40 +- 19 files changed, 7 insertions(+), 4992 deletions(-) delete mode 100644 examples/puzzletron/nemo_export/convert_hf_to_nemo.py delete mode 100644 examples/puzzletron/nemo_export/convert_nemo_to_hf.py delete mode 100644 modelopt/torch/puzzletron/decilm/conversion_utils.py delete mode 100644 modelopt/torch/puzzletron/decilm/converters/convert_llama3_to_decilm.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/tokenization_decilm.py delete mode 100644 modelopt/torch/puzzletron/export/MCore/llama_nemotron.py delete mode 100644 modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py delete mode 100644 modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py delete mode 100644 modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 546423fa77..b278013bb8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -113,6 +113,8 @@ repos: examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| examples/speculative_decoding/server_generate.py| + examples/puzzletron/evaluation/lm_eval_anymodel.py| + modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py| experimental/dms/models/qwen3/configuration_qwen3_dms.py| experimental/dms/models/qwen3/modeling_qwen3_dms.py| )$ diff --git a/examples/puzzletron/nemo_export/convert_hf_to_nemo.py b/examples/puzzletron/nemo_export/convert_hf_to_nemo.py deleted file mode 100644 index 0cf16b4486..0000000000 --- a/examples/puzzletron/nemo_export/convert_hf_to_nemo.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -from pathlib import Path -from typing import Any - -from nemo.collections import llm - -from modelopt.torch.puzzletron.export.MCore.llama_nemotron import ( - PuzzletronLlamaNemotronModel, - PuzzletronNemotronModelConfig, -) - - -def convert_model( - hf_model_path_local: str, output_path_nemo_local: str, overwrite: bool = False -) -> Any: - """Convert a Puzzletron HuggingFace model to NeMo format. - - Args: - hf_model_path_local: Path to the input Puzzletron HuggingFace model directory - output_path_nemo_local: Path where the converted Puzzletron NeMo model will be saved - overwrite: Whether to overwrite existing output directory - """ - - model = PuzzletronLlamaNemotronModel(config=PuzzletronNemotronModelConfig) - # NOTE: API call to import_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/294ddff187f68c055d87ffe9400e65975b38693d/nemo/collections/llm/api.py#L888 - print( - f"calling import_ckpt with model: {model}, " - f"source: {hf_model_path_local}, " - f"output_path: {output_path_nemo_local}, " - f"overwrite: {overwrite}" - ) - nemo2_path = llm.import_ckpt( - model=model, - source="hf://" + hf_model_path_local, - output_path=Path(output_path_nemo_local), - overwrite=overwrite, - ) - - print(f"Model saved to {nemo2_path}") - return nemo2_path - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Convert Puzzletron HuggingFace model to NeMo format" - ) - parser.add_argument( - "--input-ckpt-path", - "-i", - type=str, - required=True, - help="Path to the input Puzzletron HuggingFace model directory", - ) - parser.add_argument( - "--output-ckpt-path", - "-o", - type=str, - required=True, - help="Path where the converted Puzzletron NeMo model will be saved", - ) - parser.add_argument( - "--overwrite", - action="store_true", - default=False, - help="Whether to overwrite existing output directory (default: False)", - ) - - args = parser.parse_args() - - # Validate input path - if not os.path.exists(args.input_ckpt_path): - raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}") - - # Create output directory if it doesn't exist - os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True) - - print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}") - convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite) - - -if __name__ == "__main__": - main() diff --git a/examples/puzzletron/nemo_export/convert_nemo_to_hf.py b/examples/puzzletron/nemo_export/convert_nemo_to_hf.py deleted file mode 100644 index 4645ae5b43..0000000000 --- a/examples/puzzletron/nemo_export/convert_nemo_to_hf.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -from pathlib import Path -from typing import Any - -from nemo.collections import llm - -from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import copy_deci_lm_hf_code - - -def convert_model( - nemo_model_path_local: str, output_path_hf_local: str, overwrite: bool = False -) -> Any: - """Convert a NeMo model to HuggingFace format. - - Args: - nemo_model_path_local: Path to the input NeMo model file (.nemo) - output_path_hf_local: Path where the converted HuggingFace model will be saved - overwrite: Whether to overwrite existing output directory - """ - - # NOTE: API call to export_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/llm/api.py#L987 - print( - f"calling export_ckpt with path: {nemo_model_path_local}, " - f"target: hf, output_path: {output_path_hf_local}, " - f"target_model_name: PuzzletronLlamaNemotronModel, " - f"overwrite: {overwrite}" - ) - - hf_path = llm.export_ckpt( - path=nemo_model_path_local, - target="hf", - output_path=Path(output_path_hf_local), - target_model_name="PuzzletronLlamaNemotronModel", - overwrite=overwrite, - ) - - copy_deci_lm_hf_code(hf_path) - - print(f"Model saved to {hf_path}") - return hf_path - - -def main() -> None: - parser = argparse.ArgumentParser(description="Convert NeMo model to HuggingFace format") - parser.add_argument( - "--input-ckpt-path", - "-i", - type=str, - required=True, - help="Path to the input NeMo model checkpoint", - ) - parser.add_argument( - "--output-ckpt-path", - "-o", - type=str, - required=True, - help="Path where the converted Puzzletron HuggingFace model will be saved", - ) - parser.add_argument( - "--overwrite", - action="store_true", - default=False, - help="Whether to overwrite existing output directory (default: False)", - ) - - args = parser.parse_args() - - # Validate input path - if not os.path.exists(args.input_ckpt_path): - raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}") - - # Create output directory if it doesn't exist - os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True) - - print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}") - convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite) - - -if __name__ == "__main__": - main() diff --git a/modelopt/torch/puzzletron/decilm/conversion_utils.py b/modelopt/torch/puzzletron/decilm/conversion_utils.py deleted file mode 100644 index deb080ea21..0000000000 --- a/modelopt/torch/puzzletron/decilm/conversion_utils.py +++ /dev/null @@ -1,157 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -import re -from collections import defaultdict - -from safetensors.torch import load_file, save_file -from tqdm import tqdm - - -def convert_name(name): - return name.replace("feed_forward", "mlp").replace("language_model.", "") - - -def convert_routed_experts_weight(llama_name, weight): - assert ".experts." in llama_name, "Only use this func to convert weights of routed experts" - llama_name_prefix = llama_name.split(".experts.")[0] - deci_name_prefix = convert_name(llama_name_prefix) - - experts_state_dict = {} - for i_expert, expert_weight in enumerate(weight.unbind(dim=0)): - expert_prefix = f"{deci_name_prefix}.experts.{i_expert}" - if "gate_up_proj" in llama_name: - gate_weight, up_weight = expert_weight.transpose(0, 1).chunk(2, dim=0) - experts_state_dict[f"{expert_prefix}.gate_proj.weight"] = gate_weight.contiguous() - experts_state_dict[f"{expert_prefix}.up_proj.weight"] = up_weight.contiguous() - elif "down_proj" in llama_name: - down_weight = expert_weight.transpose(0, 1) - experts_state_dict[f"{expert_prefix}.down_proj.weight"] = down_weight.contiguous() - else: - raise ValueError(f"Unknown expert weight: {llama_name}") - - return experts_state_dict - - -def get_layer_subblock(param): - if param.startswith("model.embed_tokens."): - return "embeddings" - if param.startswith("lm_head.") or param == "model.norm.weight": - return "lm_head" - m = re.match(r"model\.layers\.(\d+)\.(.+)", param) - if m: - layer, suffix = m.groups() - if suffix.startswith(("self_attn.", "input_layernorm.weight")): - return f"block_{layer}_attention" - elif suffix.startswith(("mlp.", "post_attention_layernorm.weight")): - return f"block_{layer}_ffn" - return None - - -def convert_model_weights_to_decilm(llama_hf_dir, output_dir, is_llama4=False): - index_path = os.path.join(llama_hf_dir, "model.safetensors.index.json") - single_file_path = os.path.join(llama_hf_dir, "model.safetensors") - - # Check if we have a sharded model (with index) or single file model - if os.path.exists(index_path): - # Sharded model - use existing logic - with open(index_path) as f: - index = json.load(f) - param_to_file = index["weight_map"] - all_param_names = list(param_to_file.keys()) - elif os.path.exists(single_file_path): - # Single file model - create a synthetic index - data = load_file(single_file_path) - all_param_names = list(data.keys()) - param_to_file = dict.fromkeys(all_param_names, "model.safetensors") - else: - raise FileNotFoundError( - f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." - ) - name_map = { - name: convert_name(name) - for name in all_param_names - if name.startswith("language_model.") or not is_llama4 - } - - # Reverse map: file -> set of params - file_to_params = defaultdict(set) - for name, file in param_to_file.items(): - file_to_params[file].add(name) - - # Determine subblocks needed - subblocks = defaultdict(list) - for old_name, new_name in name_map.items(): - subblock = get_layer_subblock(new_name) - if subblock: - subblocks[subblock].append((old_name, new_name)) - - # Output directory - out_dir = os.path.join(output_dir, "subblocks_safetensors") - os.makedirs(out_dir, exist_ok=True) - - # New weight index - new_index = {"metadata": {"format": "pt"}, "weight_map": {}} - - # For single file models, load all data once - if os.path.exists(single_file_path) and not os.path.exists(index_path): - all_data = load_file(single_file_path) - else: - all_data = None - - for subblock, param_pairs in tqdm(subblocks.items(), desc="Processing subblocks"): - tensors = {} - - if all_data is not None: - # Single file model - get tensors from pre-loaded data - for old_name, new_name in param_pairs: - if old_name in all_data: - if ".experts." not in old_name: - tensors[new_name] = all_data[old_name] - else: - experts_state_dict = convert_routed_experts_weight( - old_name, all_data[old_name] - ) - tensors.update(experts_state_dict) - else: - # Sharded model - load only needed files for this subblock - param_files = {param_to_file[old] for old, _ in param_pairs} - for file in param_files: - data = load_file(os.path.join(llama_hf_dir, file)) - for old_name, new_name in param_pairs: - if param_to_file[old_name] == file and old_name in data: - if ".experts." not in old_name: - tensors[new_name] = data[old_name] - else: - experts_state_dict = convert_routed_experts_weight( - old_name, data[old_name] - ) - tensors.update(experts_state_dict) - - # Save this subblock - subblock_file = f"{subblock}.safetensors" - save_file(tensors, os.path.join(out_dir, subblock_file)) - - # Update index - for new_name in tensors: - new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}" - - # Save new index file - with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f: - json.dump(new_index, f, indent=2) - - print(f"✅ Finished saving subblocks and index to {output_dir}") diff --git a/modelopt/torch/puzzletron/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/puzzletron/decilm/converters/convert_llama3_to_decilm.py deleted file mode 100644 index c5f107ea1e..0000000000 --- a/modelopt/torch/puzzletron/decilm/converters/convert_llama3_to_decilm.py +++ /dev/null @@ -1,153 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert a Llama3 model to a DeciLM model.""" - -#!/usr/bin/env python3 -from pathlib import Path - -import torch -from fire import Fire -from transformers import LlamaConfig - -from modelopt.torch.puzzletron.decilm.conversion_utils import convert_model_weights_to_decilm -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer -from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import copy_deci_lm_hf_code - -""" -example: - -python -m scripts.hf.convert_llama3_to_decilm \ - --input_dir .../meta-llama/Meta-Llama-3.1-8B-Instruct \ - --output_dir .../meta-llama/Meta-Llama-3.1-8B-Instruct--deci-hf/ -""" - - -def convert_llama3_config_to_decilm_config(config: LlamaConfig) -> DeciLMConfig: - """Convert Llama3 config to DeciLM config format.""" - print("\n=== Converting Llama3 Config to DeciLM Config ===") - - # Get dtype from config - check both dtype and torch_dtype - # Prefer dtype if it's set (not None), otherwise fall back to torch_dtype - dtype = getattr(config, "dtype", None) - if dtype is None: - dtype = getattr(config, "torch_dtype", None) - - # Convert torch.dtype to string if needed (for JSON serialization) - if dtype is not None and isinstance(dtype, torch.dtype): - dtype = str(dtype).replace("torch.", "") - - # Track which global values will be removed (moved to per-layer configs) - print("\n📝 Converting global values to per-layer block_configs:") - print( - f" - intermediate_size: {config.intermediate_size} → block_configs[*].ffn.intermediate_size" - ) - print( - f" - num_key_value_heads: {config.num_key_value_heads} → block_configs[*].attention.n_heads_in_group (derived)" - ) - print(f" - hidden_act: {config.hidden_act} → block_configs[*].ffn.hidden_act") - print( - f" - sliding_window: {getattr(config, 'sliding_window', None)} → block_configs[*].attention.window_length" - ) - - # Create block configs for each layer - block_configs = [] - for i in range(config.num_hidden_layers): - # Configure attention - attention_config = { - "no_op": False, - "replace_with_linear": False, - "sparsify": None, - "n_heads_in_group": config.num_attention_heads // config.num_key_value_heads, - "window_length": None, # Llama3 doesn't use sliding window by default - "num_sink_tokens": None, # Llama3 doesn't use sink attention - "use_prefill_window_in_sink_attention": False, - "unshifted_sink": False, - "mamba": None, - "llama4": None, # No Llama4 specific attention for Llama3 - } - - # Configure FFN - ffn_config = { - "no_op": False, - "replace_with_linear": False, - "sparsify": None, - "intermediate_size": config.intermediate_size, - "gated": True, # Llama3 uses SwiGLU - "hidden_act": config.hidden_act, - "moe": None, # Llama3 doesn't use MoE - } - - block_configs.append({"attention": attention_config, "ffn": ffn_config}) - - # Create DeciLM config - decilm_config = DeciLMConfig( - block_configs=block_configs, - hidden_size=config.hidden_size, - max_position_embeddings=config.max_position_embeddings, - num_attention_heads=config.num_attention_heads, - num_hidden_layers=config.num_hidden_layers, - tie_word_embeddings=config.tie_word_embeddings, - vocab_size=config.vocab_size, - rms_norm_eps=config.rms_norm_eps, - attention_bias=config.attention_bias, - o_proj_bias=config.attention_bias, # llama3 bias defined by attention_bias - rope_theta=config.rope_theta, - rope_scaling=config.rope_scaling, - position_embedding_type="rope", # Llama3 uses standard RoPE - architectures=["DeciLMForCausalLM"], - auto_map={ - "AutoConfig": "configuration_decilm.DeciLMConfig", - "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM", - }, - eos_token_id=config.eos_token_id, - bos_token_id=config.bos_token_id, - pad_token_id=config.pad_token_id, - head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), - dtype=dtype, - ) - - print(f"\n✓ Created DeciLM config with {len(block_configs)} layers") - print( - "✓ Global per-layer keys (intermediate_size, num_key_value_heads, hidden_act, sliding_window)" - ) - print(" will be removed from saved config and are only in block_configs") - - return decilm_config - - -def convert_configs_in_dirs(input_dir, output_dir): - """Convert the config of a Llama3 model to a DeciLM model.""" - input_dir = Path(input_dir) - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - input_config_path = input_dir / "config.json" - config = LlamaConfig.from_pretrained(input_config_path) - decilm_config = convert_llama3_config_to_decilm_config(config) - decilm_config.save_pretrained(output_dir) - - -def convert_llama3_to_decilm(input_dir, output_dir): - """Convert a Llama3 model to a DeciLM model.""" - convert_configs_in_dirs(input_dir, output_dir) - copy_tokenizer(input_dir, output_dir) - convert_model_weights_to_decilm(input_dir, output_dir) - copy_deci_lm_hf_code(output_dir) - - -if __name__ == "__main__": - Fire(convert_llama3_to_decilm) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py deleted file mode 100644 index 1b3840a300..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py +++ /dev/null @@ -1,148 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from abc import ABC, abstractmethod -from collections import OrderedDict -from typing import Any - -import numpy - - -class MegatronTokenizer(ABC): - """Abstract class for tokenizer - - Absent a config or class-specific tracking of which objects are uniquely identifying, we must - include all key word arguments as unique identifiers - - Args: - tokenizer_paths (Tuple[str]): All tokenizer source paths or prefixes - - tokenizer_options (Dict[str, Any]): All tokenizer options - """ - - def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any): - self.unique_identifiers = OrderedDict() - self.unique_identifiers["class"] = type(self).__name__ - self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths) - for option in tokenizer_options: - self.unique_identifiers[option] = str(tokenizer_options[option]) - - self.unique_description = json.dumps(self.unique_identifiers, indent=4) - - super().__init__() - - @abstractmethod - def tokenize(self, text: str) -> numpy.ndarray: - """Convert text to embedding ids - - Args: - text (str): The text to convert - - Returns: - numpy.ndarray: The converted embedding ids - """ - - def detokenize(self, ids: numpy.ndarray) -> str: - """Convert embedding ids to text - - Args: - ids (numpy.ndarray): The ids to convert - - Returns: - str: The converted text - - Raises: - NotImplementedError: Non-abstract, optional method - """ - raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__)) - - @property - @abstractmethod - def vocab(self): - """Dictionary from vocab text token to id token""" - - @property - @abstractmethod - def inv_vocab(self): - """Dictionary from vocab id token to text token""" - - @property - @abstractmethod - def vocab_size(self): - """The vocabulary size""" - - @property - def cls(self): - """The CLS token id - - Raises: - NotImplementedError: Non-abstract, optional attribute - """ - raise NotImplementedError("{} has no attribute 'cls'".format(type(self).__name__)) - - @property - def sep(self): - """The SEP token id - - Raises: - NotImplementedError: Non-abstract, optional attribute - """ - raise NotImplementedError("{} has no attribute 'sep'".format(type(self).__name__)) - - @property - def pad(self): - """The PAD token id - - Raises: - NotImplementedError: Non-abstract, optional attribute - """ - raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__)) - - @property - def eod(self): - """The EOD token id - - Raises: - NotImplementedError: Non-abstract, optional attribute - """ - raise NotImplementedError("{} has no attribute 'eod'".format(type(self).__name__)) - - @property - def bos(self): - """The BOS token id - - Raises: - NotImplementedError: Non-abstract, optional attribute - """ - raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__)) - - @property - def eos(self): - """The EOS token id - - Raises: - NotImplementedError: Non-abstract, optional attribute - """ - raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__)) - - @property - def mask(self): - """The MASK token id - - Raises: - NotImplementedError: Non-abstract, optional attribute - """ - raise NotImplementedError("{} has no attribute 'mask'".format(type(self).__name__)) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py deleted file mode 100644 index 5c641d25b9..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py +++ /dev/null @@ -1,187 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors - -"""Megatron tokenizers.""" - -import base64 -import json -from pathlib import Path - -from .megatron_lm__megatron_tokenizer import MegatronTokenizer - - -def reload_mergeable_ranks( - path: str, - max_vocab: int | None = None, -) -> dict[bytes, int]: - """ - Reload our tokenizer JSON file and convert it to Tiktoken format. - """ - assert path.endswith(".json") - - # reload vocab - with open(path) as f: - vocab = json.load(f) - assert isinstance(vocab, list) - print(f"Vocab size: {len(vocab)}") - if max_vocab is not None: - vocab = vocab[:max_vocab] - print(f"Cutting vocab to first {len(vocab)} tokens.") - - # build ranks - ranks: dict[bytes, int] = {} - for i, x in enumerate(vocab): - assert x.keys() == {"rank", "token_bytes", "token_str"} - assert x["rank"] == i - merge = base64.b64decode(x["token_bytes"]) - assert i >= 256 or merge == bytes([i]) - ranks[merge] = x["rank"] - - # sanity check - assert len(ranks) == len(vocab) - assert set(ranks.values()) == set(range(len(ranks))) - - return ranks - - -PATTERN_TIKTOKEN = ( - r"[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" -) -PATTERN_TIKTOKEN_V2 = ( - "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+" - "|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*" - "|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" -) - - -class CustomTikTokenizer(MegatronTokenizer): - def __init__( - self, - path: str, - pattern: str, - vocab_size: int, - num_special_tokens: int, - special_tokens: list[str] | None, - ): - super().__init__( - path, - pattern=pattern, - vocab_size=vocab_size, - num_special_tokens=num_special_tokens, - special_tokens=special_tokens, - ) - import tiktoken - - # if vocab_size is None: - # vocab_size = 2**17 # Fallback vocab size is 131072. - self._vocab_size = vocab_size - - special_tokens_default = ["", "", ""] - if special_tokens is None: - special_tokens = special_tokens_default.copy() - assert len(special_tokens) == len(set(special_tokens)), ( - f"Special tokens should be unique: {special_tokens}" - ) - assert len(special_tokens) <= num_special_tokens < self._vocab_size - assert set(special_tokens_default) <= set(special_tokens), ( - f"Custom special tokens should include {special_tokens_default}" - ) - - special_filler = [ - "".format(id=i) for i in range(len(special_tokens), num_special_tokens) - ] - if special_filler: - print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}") - special_tokens = special_tokens + special_filler - assert len(set(special_tokens)) == len(special_tokens) == num_special_tokens, special_tokens - inner_vocab_size = self._vocab_size - num_special_tokens - - token_to_id_without_special_tokens = reload_mergeable_ranks( - path, max_vocab=inner_vocab_size - ) - # Create space for special tokens. - token_to_id_without_special_tokens = { - t: i + num_special_tokens for t, i in token_to_id_without_special_tokens.items() - } - - special_tokens = {t: i for i, t in enumerate(special_tokens)} - self._unk_id = special_tokens[""] - self._bos_id = special_tokens[""] - self._eos_id = special_tokens[""] - - # Create tiktoken model. - self._model = tiktoken.Encoding( - name=Path(path).parent.name, - pat_str=pattern, - mergeable_ranks=token_to_id_without_special_tokens, - special_tokens=special_tokens, - ) - - # Create final _id_to_token and _token_to_id data structures with special tokens inserted - # into appropriate locations. - assert set(token_to_id_without_special_tokens.keys()).isdisjoint(set(special_tokens.keys())) - self._token_to_id = token_to_id_without_special_tokens.copy() - self._token_to_id.update(special_tokens) - self._id_to_token = {v: k for k, v in self._token_to_id.items()} - assert set(range(self._vocab_size)) == set(self._id_to_token.keys()) - - @property - def bos(self) -> int: - return self._bos_id - - @property - def eos(self) -> int: - return self._eos_id - - @property - def unk(self) -> int: - return self._unk_id - - @property - def eod(self) -> int: - return self._eos_id - - @property - def vocab(self): - return self._token_to_id - - @property - def inv_vocab(self): - return self._id_to_token - - def tokenize(self, s: str, bos: bool = False, eos: bool = False) -> list[int]: - tokens = self._model.encode_ordinary(s) - if bos: - tokens = [self.bos, *tokens] - if eos: - tokens = [*tokens, self.eos] - - return tokens - - def detokenize(self, tokens: list[int]) -> str: - return self._model.decode(tokens) - - @property - def vocab_size(self) -> int: - return self._vocab_size - - @property - def encoder(self): - return self._token_to_id - - @property - def decoder(self): - return self._id_to_token diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py index 24be1b227d..84496bc4a3 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py @@ -39,18 +39,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers import GenerationConfig -from transformers.generation.utils import GenerationMixin -from transformers.modeling_utils import PreTrainedModel -from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) +from transformers.utils import is_flash_attn_greater_or_equal_2_10, logging from .block_config import AttentionConfig, FFNConfig, MambaConfig, MoEConfig from .configuration_decilm import DeciLMConfig @@ -61,15 +50,6 @@ from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import ( _flash_attention_forward, ) -from .transformers_4_44_2__modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS from .transformers_4_51_3__modeling_llama4_attention import Llama4TextAttention, Llama4TextConfig @@ -77,7 +57,6 @@ from .vllm_yarn_utils import YaRNScalingRotaryEmbedding # from transformers.models.llama4.modeling_llama4 import Llama4TextL2Norm -MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[DeciLMConfig.model_type] = "DeciLMForCausalLM" logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeciLMConfig" @@ -1588,673 +1567,6 @@ def forward( return outputs -DECILM_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`DeciLMConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare DeciLM Model outputting raw hidden-states without any specific head on top.", - DECILM_START_DOCSTRING, -) -class DeciLMPreTrainedModel(PreTrainedModel): - config_class = DeciLMConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["DeciLMDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True # all the _supports_... flags refer to the Llama3 layers - _supports_sdpa = False - _supports_flex_attn = False - _supports_cache_class = True - _supports_quantized_cache = False - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _prepare_generation_config( - self, - generation_config: GenerationConfig | None, - *args, - **kwargs, - ) -> tuple[GenerationConfig, dict]: - try: - from transformers import cache_utils - from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING - - need_setup_cache_classes_mapping = NEED_SETUP_CACHE_CLASSES_MAPPING - except Exception: - # older releases exposed it via generation.utils - need_setup_cache_classes_mapping = {} - - # DeciLM-specific code - generation_config, model_kwargs = super()._prepare_generation_config( - generation_config, *args, **kwargs - ) - # New transformers version, can reach only through cache_utils - if need_setup_cache_classes_mapping == {}: - cache_utils._CACHE_IMPLEMENTATION_MAPPING["variable"] = VariableCache - else: - need_setup_cache_classes_mapping["variable"] = VariableCache - - generation_config.cache_implementation = "variable" - return generation_config, model_kwargs - - -DECILM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`VariableCache`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - If passed to the forward function, past_key_values must be a VariableCache object (see imports). - For generation purposes, this is already handled inside model.generate(). - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare DeciLM Model outputting raw hidden-states without any specific head on top.", - DECILM_START_DOCSTRING, -) -class DeciLMModel(DeciLMPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`] - - Args: - config: DeciLMConfig - """ - - def __init__(self, config: DeciLMConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [ - ( - DeciLMDecoderLayer(config, layer_idx) - if (config.block_configs[layer_idx].parallel_blocks is None) - else DeciLMMultiDecoderLayer(config, layer_idx) - ) - for layer_idx in range(config.num_hidden_layers) - ] - ) - self.norm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if self.config.position_embedding_type in ["rope", "rope_llama4", "mistral_yarn"]: - self.rotary_emb = rope_type_to_class[self.config.position_embedding_type](config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def get_final_layer_norm(self): - return self.norm - - def set_final_layer_norm(self, value): - self.norm = value - - @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_router_logits: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, - ) -> tuple | BaseModelOutputWithPast: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - is_legacy_cache_format = (past_key_values is not None) and type( - past_key_values - ).__name__ != "VariableCache" - # We use the __name__ instead of isinstance to support weird use cases - # (init cache from a checkpoint dir and use it with local code) - if is_legacy_cache_format: - raise NotImplementedError( - "DeciLMModel does not support legacy cache format, please use a newer " - "transformers version or use VariableCache explicitly (see import in this file)." - ) - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - # use default device - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = None - if hasattr(self, "rotary_emb"): - # rotary emb is created all devices, so we need to move position_ids to the correct device - some_param = next(self.parameters()) - position_ids = position_ids.to(some_param.device) - cache_position = cache_position.to(some_param.device) - faux_hidden_states = position_ids.to(some_param.dtype) - position_embeddings = self.rotary_emb(faux_hidden_states, position_ids) - # print(f'START {position_embeddings.device=}') # HF hook will change the device - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - if self.config.block_return_only_hidden_states: - hidden_states = layer_outputs - next_decoder_cache = past_key_values - - else: - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - # Extract router logits if they exist - if output_router_logits: - router_logits_index = -1 # Router logits are always the last element - if len(layer_outputs) > (2 if output_attentions else 1) + ( - 1 if use_cache else 0 - ): - all_router_logits += (layer_outputs[router_logits_index],) - - # Final layer norm - hidden_states = hidden_states.to(next(self.parameters()).device) - hidden_states = self.norm(hidden_states) - - # Add the last hidden state - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # Set the next cache - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - outputs = (hidden_states, next_cache, all_hidden_states, all_self_attns) - if output_router_logits: - outputs += (all_router_logits,) - return outputs - - # Handle different return types based on whether router logits are requested - if output_router_logits and all_router_logits: - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) - else: - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -@add_start_docstrings( - """ - The DeciLM Model transformer with a sequence classification head on top (linear layer). - - [`DeciLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - DECILM_START_DOCSTRING, -) -class DeciLMForSequenceClassification(DeciLMPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = DeciLMModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | SequenceClassifierOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - elif input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits, *transformer_outputs[1:]) - return (loss, *output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ -The DeciLM Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - DECILM_START_DOCSTRING, -) -class DeciLMForQuestionAnswering(DeciLMPreTrainedModel): - base_model_prefix = "transformer" - - # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->DeciLM - def __init__(self, config): - super().__init__(config) - self.transformer = DeciLMModel(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.transformer.embed_tokens - - def set_input_embeddings(self, value): - self.transformer.embed_tokens = value - - @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.FloatTensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - start_positions: torch.LongTensor | None = None, - end_positions: torch.LongTensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | QuestionAnsweringModelOutput: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1).to(start_logits.device) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1).to(end_logits.device) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits, *outputs[2:]) - return (total_loss, *output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - The DeciLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - DECILM_START_DOCSTRING, -) -class DeciLMForTokenClassification(DeciLMPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = DeciLMModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | TokenClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits, *outputs[2:]) - return (loss, *output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - ######################################################################## # DeciLM-specific code ######################################################################## @@ -2430,201 +1742,3 @@ class LMHead(nn.Linear): """ Special class to allow FSDP wrapping without affecting other Linear layers in the model. """ - - -class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: DeciLMConfig): - super().__init__(config) - self.model = DeciLMModel(config) - self.vocab_size = config.vocab_size - self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def compute_router_aux_loss(self, router_logits): - """ - Computes the auxiliary loss for router logits. - This encourages load balancing across experts. - - Args: - router_logits: List of router logits tensors from each MoE layer - Each tensor has shape [batch_size, sequence_length, num_experts] - - Returns: - Auxiliary loss tensor - """ - aux_loss = torch.tensor(0.0, device=router_logits[0].device) - - for layer_idx, layer_router_logits in enumerate(router_logits): - router_probs = torch.softmax(layer_router_logits, dim=-1) - - # Mean routing probability across batch and sequence dimensions - mean_prob = router_probs.mean(dim=[0, 1]) - - # Compute auxiliary loss: combination of load balancing and importance loss - # Load balancing loss: variance of expert usage probabilities (should be uniform) - num_experts = mean_prob.size(0) - ideal_prob = 1.0 / num_experts - balance_loss = torch.sum((mean_prob - ideal_prob) ** 2) - - # Add this layer's auxiliary loss to the total - aux_loss = aux_loss + balance_loss - - # Average over all layers - if len(router_logits) > 0: - aux_loss = aux_loss / len(router_logits) - - return aux_loss - - @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - output_router_logits: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Return: - """ - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - return_dict=return_dict, - cache_position=cache_position, - ) - - # Extract model outputs based on return type - if isinstance(outputs, MoeModelOutputWithPast): - hidden_states = outputs.last_hidden_state - router_logits = outputs.router_logits - elif return_dict: - hidden_states = outputs.last_hidden_state - router_logits = None # No router logits in this case - else: - hidden_states = outputs[0] - router_logits = outputs[4] if output_router_logits and len(outputs) > 4 else None - - # Generate logits - logits = self.lm_head(hidden_states) - logits = logits.float() - - # Calculate loss if labels are provided - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - # Calculate router aux loss if router logits are present - if router_logits is not None and self.config.router_aux_loss_coef > 0: - aux_loss = self.compute_router_aux_loss(router_logits) - loss = loss + aux_loss * self.config.router_aux_loss_coef - - # Handle non-dict return - if not return_dict: - output = (logits,) - if isinstance(outputs, tuple): - output += outputs[1:] # Add all other outputs - return (loss, *output) if loss is not None else output - - # Different output types for MoE vs regular model - if router_logits is not None: - return MoeCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values if return_dict else outputs[1], - hidden_states=outputs.hidden_states - if return_dict - else outputs[2] - if output_hidden_states - else None, - attentions=outputs.attentions - if return_dict - else outputs[3] - if output_attentions - else None, - router_logits=router_logits, - ) - else: - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values if return_dict else outputs[1], - hidden_states=outputs.hidden_states - if return_dict - else outputs[2] - if output_hidden_states - else None, - attentions=outputs.attentions - if return_dict - else outputs[3] - if output_attentions - else None, - ) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/tokenization_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/tokenization_decilm.py deleted file mode 100644 index 14c840b8b1..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/tokenization_decilm.py +++ /dev/null @@ -1,195 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors - -""" -Only needed for DeciLM models that use Megatron tokenizers. -DeciLM models that use Llama tokenizers do not need external code. -""" - -import json -import os -from pathlib import Path -from typing import Literal - -from transformers import PreTrainedTokenizer -from transformers.dynamic_module_utils import custom_object_save -from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE, AddedToken - -from .megatron_lm__megatron_tokenizer import ( - MegatronTokenizer, # fake import to make AutoTokenizer infer the dependency -) -from .megatron_lm__tokenizer import PATTERN_TIKTOKEN, PATTERN_TIKTOKEN_V2, CustomTikTokenizer - -MegatronTokenizer # make sure that auto-formatting doesn't remove the import - - -class MegatronTikTokenizer(PreTrainedTokenizer): - vocab_files_names: dict[str, str] = {"vocab_file": "tiktoken_vocab.json"} - model_input_names: list[str] = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file: str, - tiktoken_pattern: Literal["v1", "v2"], - vocab_size: int, - tiktoken_num_special_tokens: int, - tiktoken_special_tokens: list[str] | None, - add_bos_token: bool = False, # nm5 does not use bos token - add_eos_token: bool = False, # nm5 does not use eos token - **unused_kwargs, - ): - assert "chat_template" not in unused_kwargs, ( - "We enforce the Nemotron5 chat template from the code, " - "please do not provide a chat_template in the tokenizer_config.json file" - ) - - pattern = PATTERN_TIKTOKEN if tiktoken_pattern == "v1" else PATTERN_TIKTOKEN_V2 - self._tokenizer = CustomTikTokenizer( - path=vocab_file, - pattern=pattern, - vocab_size=vocab_size, - num_special_tokens=tiktoken_num_special_tokens, - special_tokens=tiktoken_special_tokens, - ) - - eos_token = self._tokenizer.detokenize([self._tokenizer.eos]) - bos_token = self._tokenizer.detokenize([self._tokenizer.bos]) - self.vocab = self._tokenizer.vocab - super().__init__( - eos_token=AddedToken(eos_token, normalized=False, special=True), - bos_token=AddedToken(bos_token, normalized=False, special=True), - pad_token=AddedToken(eos_token, normalized=False, special=True), - ) - - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - self.chat_template = NEMOTRON5_CHAT_TEMPLATE - - self._vocab_file_contents = Path(vocab_file).read_text() - self._tokenizer_config = { - "tiktoken_pattern": tiktoken_pattern, - "vocab_size": vocab_size, - "tiktoken_num_special_tokens": tiktoken_num_special_tokens, - "tiktoken_special_tokens": tiktoken_special_tokens, - "add_bos_token": add_bos_token, - "add_eos_token": add_eos_token, - "tokenizer_class": "MegatronTikTokenizer", - "auto_map": { - "AutoTokenizer": ["tokenization_decilm.MegatronTikTokenizer", None], - }, - } - - def get_vocab(self) -> dict[str, int]: - """to satisfy PreTrainedTokenizer.__init__()""" - return self.vocab - - def tokenize(self, text: str, **kwargs) -> list[str]: - return [text] - - def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: - is_single_token = isinstance(tokens, str) - if is_single_token: - text = tokens - else: - assert len(tokens) == 1 - text = tokens[0] - - ids = self._tokenizer._model.encode(text, allowed_special="all") - - if is_single_token: - assert len(ids) == 1, ( - f"Asked to convert a single token to its id, but it's not a single token: encode('{tokens}') = {ids}" - ) - return ids[0] - else: - return ids - - def convert_ids_to_tokens( - self, ids: int | list[int], skip_special_tokens: bool = False - ) -> str | list[str]: - is_single_id = isinstance(ids, int) - if is_single_id: - ids = [ids] - - if skip_special_tokens: - ids = [idd for idd in ids if idd not in (self.eos_token_id, self.bos_token_id)] - - text = self._tokenizer.detokenize(ids) - - if is_single_id: - return text - else: - return [text] - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - """Taken from LlamaTokenizer""" - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output - - def save_pretrained( - self, - save_directory: str | os.PathLike, - legacy_format: bool | None = None, - filename_prefix: str | None = None, - push_to_hub: bool = False, - **kwargs, - ) -> tuple[str, ...]: - assert legacy_format is None, "Unsupported" - assert filename_prefix is None, "Unsupported" - assert not push_to_hub, "Unsupported" - - save_directory = Path(save_directory) - save_directory.mkdir(parents=True, exist_ok=True) - - tokenizer_config_path = save_directory / TOKENIZER_CONFIG_FILE - tokenizer_config_path.write_text(json.dumps(self._tokenizer_config, indent=2)) - - vocab_files_name = self.vocab_files_names["vocab_file"] - vocab_file_path = save_directory / vocab_files_name - vocab_file_path.write_text(self._vocab_file_contents) - - custom_object_save(self, save_directory) - - return str(tokenizer_config_path), str(vocab_file_path) - - -NEMOTRON5_CHAT_TEMPLATE = """{% if messages[0].role != "system" %} - {% set messages = [{"role": "system", "content": ""}] + messages %} -{% endif %} -{% for message in messages %} - {% if message.role == "system" %} -System -{{ message.content }} - {% elif message.role == "user" %} -User -{{ message.content }} - {% elif message.role == "assistant" %} -Assistant -{{ message.content }} - {% endif %} -{% endfor %} -{% if add_generation_prompt %} -Assistant -{% else %} - -{% endif %}""" diff --git a/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py b/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py deleted file mode 100644 index d4292322f7..0000000000 --- a/modelopt/torch/puzzletron/export/MCore/llama_nemotron.py +++ /dev/null @@ -1,1015 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# based on https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/llm/gpt/model/llama_nemotron.py - -import json -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, Optional, Union - -import torch -import torch.nn.functional as F -from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel, torch_dtype_from_mcore_config -from nemo.collections.llm.gpt.model.llama import ( - Llama3Config, - Llama31Config, - Llama31Config70B, - LlamaConfig, - apply_rope_scaling, -) -from nemo.collections.llm.utils import Config -from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME -from nemo.lightning.io.pl import ckpt_to_weights_subdir -from nemo.lightning.io.state import TransformFns -from nemo.lightning.pytorch.utils import dtype_from_hf, dtype_from_str -from nemo.utils import logging -from nemo.utils.import_utils import safe_import -from torch import nn - -from modelopt.torch.puzzletron.tools.logger import mprint - -# from nemo.collections.llm.gpt.model.llama_nemotron import Llama33NemotronSuper49BConfig - - -_, HAVE_TE = safe_import("transformer_engine") -from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( - get_gpt_heterogeneous_layer_spec, -) -from megatron.core.transformer.heterogeneous.heterogeneous_config import ( - HeterogeneousTransformerConfig, -) -from megatron.core.transformer.spec_utils import ModuleSpec - -if TYPE_CHECKING: - from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel - from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer - from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec - from peft import AutoPeftModelForCausalLM, PeftConfig - from transformers import GenerationConfig, LlamaForCausalLM - from transformers import LlamaConfig as HFLlamaConfig - - from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig - -from modelopt.torch.puzzletron.export.MCore.llama_nemotron_utils import ( - _build_puzzletron_mappings_and_transforms, - _config_to_dict, - convert_attention_config_from_cfg_object, - convert_mlp_config_from_cfg_object, - convert_nemo_config_to_hf_decilm_config, - dtype_from_dict, - merge_qkv_for_puzzletron, - split_qkv_for_puzzletron, -) -from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( - PuzzletronHeterogeneousTransformerConfig, - get_gpt_heterogeneous_layer_spec_puzzletron, -) - - -def heterogeneous_layer_spec_puzzletron( - config: PuzzletronHeterogeneousTransformerConfig, -) -> ModuleSpec: - return get_gpt_heterogeneous_layer_spec_puzzletron(config, use_transformer_engine=HAVE_TE) - - -# Refactored to inherit directly from GPTConfig instead of Llama31Config70B -# This makes it easier to understand what attributes are set through the hierarchy -@dataclass -class PuzzletronNemotronModelConfig(GPTConfig, PuzzletronHeterogeneousTransformerConfig): - """Configuration for Puzzletron Nemotron models. - - DESIGN RATIONALE: - ================ - Refactored from original inheritance (Llama31Config70B + PuzzletronHeterogeneousTransformerConfig) - to explicit attribute definition for clarity and maintainability. Maintains identical behavior - to the original Llama hierarchy while enabling future flexibility. - - ATTRIBUTE ORGANIZATION: - ====================== - Explicitly defines attributes from the Llama hierarchy: - Llama31Config70B → Llama31Config → Llama3Config → LlamaConfig → GPTConfig - - FUTURE DEVELOPMENT: - ================== - Attributes can be freely modified/removed for future Puzzletron models. - In this case the tests in test_puzzletron_nemotron_config_inheritance.py will need to be updated. - Current explicit definition is for clarity during transition period. - """ - - # Override attributes from PuzzletronHeterogeneousTransformerConfig with Llama hierarchy values - # These ensure we maintain the same behavior as the original Llama31Config70B inheritance - - # ===== LlamaConfig attributes ===== - # Core model architecture - # NOTE: Default is F.silu, but this is overridden during instantiation to match all blocks - # See instantiate_nemo_config_from_adapted_dict() which enforces same activation across blocks - activation_func: Callable = F.silu - normalization: str = "RMSNorm" - gated_linear_unit: bool = True - position_embedding_type: str = "rope" - add_bias_linear: bool = False - # seq_length: int = 4096 # (will be overridden by Llama31Config70B) - attention_dropout: float = 0.0 - hidden_dropout: float = 0.0 - share_embeddings_and_output_weights: bool = False - # Fusion settings - bias_activation_fusion: bool = True - masked_softmax_fusion: bool = True - persist_layer_norm: bool = True - bias_dropout_fusion: bool = True - apply_rope_fusion: bool = True - use_transformer_engine_op_fuser: Optional[bool] = None - - # ===== Llama3Config attributes ===== - num_query_groups: int = 8 - # init_method_std: float = 0.01 # (will be overridden by Llama31Config) - layernorm_epsilon: float = 1.0e-05 - rotary_percent: float = 1.0 - - # ===== Llama31Config attributes ===== - scale_factor: float = 8.0 - low_freq_factor: float = 1.0 - high_freq_factor: float = 4.0 - old_context_len: int = 8192 - init_method_std: float = 0.02 # (overrides Llama3Config) - - # ===== Llama31Config70B attributes ===== - # Core model architecture (70B-specific) - rotary_base: int = 500_000 - seq_length: int = 131072 # (overrides LlamaConfig) - num_layers: int = 80 # - hidden_size: int = 8192 # - ffn_hidden_size: int = 28672 # - num_attention_heads: int = 64 # - kv_channels: int = 128 # (derived from hidden_size // num_attention_heads) - make_vocab_size_divisible_by: int = 128 # - - # ===== PuzzletronHeterogeneousTransformerConfig attributes ===== - # Actual new PuzzleNemotronModelConfig attributes - heterogeneous_layers_config_path: Optional[str] = None - heterogeneous_layers_config_encoded_json: Optional[str] = None - transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = ( - heterogeneous_layer_spec_puzzletron - ) - - # HF-specific metadata for lossless round-trip conversion (HF → NeMo → HF) - # Stores HF config fields that don't have direct NeMo equivalents - source_hf_config_metadata: Optional[Dict[str, Any]] = None - - # NOTE: How activation_func is handled for Puzzletron models - # ============================================================== - # Puzzletron models can define activation functions per-block, but MCore's validation - # only checks the global activation_func (not per-block activations). - # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_config.py#L1043-L1061 - # - # Current approach (enforced in instantiate_nemo_config_from_adapted_dict): - # - All blocks must use the SAME activation function (None allowed for no-op blocks) - # - The global activation_func is set to match the blocks' shared activation - # - This ensures MCore's global validation passes correctly - # - # Rationale: - # 1. MCore validates global activation_func during __post_init__() (lines 1043-1061) - # 2. NeMo calls __post_init__() AGAIN during trainer.strategy.connect(model) - # See: https://github.com/NVIDIA/NeMo/blob/2e19aebd8c8fa9ff7ce9b5076ce130404713443c/nemo/lightning/_strategy_lib.py#L172-L175 - # 3. At runtime, MCore uses per-block activations from get_config_for_layer() - # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_block.py#L308-L319 - # - # For heterogeneous activations across blocks, MCore would need to update their - # validation logic to support per-block validation (e.g., in get_config_for_layer() or MLP.__init__) - - # ===== Llama31Config method ===== - def configure_model( - self, tokenizer, pre_process=None, post_process=None, vp_stage=None - ) -> "MCoreGPTModel": - """Configure and instantiate a Megatron Core Llama 3.1 model. - - NOTE: This method is originally from Llama31Config and is explicitly included here - for consistency and clarity. It maintains the same behavior as the original - Llama hierarchy inheritance approach. - - Extends the base configuration with Llama 3.1 specific RoPE scaling. - This method applies RoPE scaling for extended context length support. - """ - model = super().configure_model(tokenizer, pre_process, post_process, vp_stage) - # Apply rope scaling for Llama3.1 model - model.rotary_pos_emb.inv_freq = apply_rope_scaling( - model.rotary_pos_emb.inv_freq, - factor=self.scale_factor, - low_freq_factor=self.low_freq_factor, - high_freq_factor=self.high_freq_factor, - old_context_len=self.old_context_len, - ) - return model - - @classmethod - def from_dict_with_preprocessing(cls, config_dict): - # Potentially adapt the config_dict before instantiation - instance = cls(**config_dict) - # Potentially adapt the config after instantiation - return instance - - # static method - @staticmethod - def create_adapted_config_dict_from_puzzletron_config(cfg): - # TODO: consider doing do this without conversion to dictionary in the future (instead have an adapted config object) - # Create an empty config object of the same class as cfg - adapted_cfg_dict = dict() - orig_cfg_dict = vars(cfg) - - # Extract first set of values from the original config - adapted_cfg_dict["head_dim"] = orig_cfg_dict["head_dim"] - adapted_cfg_dict["num_attention_heads"] = orig_cfg_dict["num_attention_heads"] - # Handle rope_scaling - can be None, missing, or a dict - adapted_cfg_dict["rope_scaling"] = orig_cfg_dict.get("rope_scaling") or {} - - block_conf = { - "block_configs": [ - { - "attention": convert_attention_config_from_cfg_object( - orig_cfg_dict["block_configs"][i].attention, - adapted_cfg_dict["num_attention_heads"], - adapted_cfg_dict["head_dim"], - ), - "mlp": { - **convert_mlp_config_from_cfg_object( - orig_cfg_dict["block_configs"][i].ffn, - ( - orig_cfg_dict["block_configs"][i].parallel_blocks - if hasattr(orig_cfg_dict["block_configs"][i], "parallel_blocks") - else None - ), - ), - # Store the per-block activation function as a string (for JSON serialization) - "hidden_act": ( - orig_cfg_dict["block_configs"][i].ffn.hidden_act - if not ( - orig_cfg_dict["block_configs"][i].ffn.no_op - or orig_cfg_dict["block_configs"][i].ffn.replace_with_linear - ) - else None - ), - }, - } - for i in range(len(orig_cfg_dict["block_configs"])) - ] - } - if orig_cfg_dict["o_proj_bias"] != orig_cfg_dict["attention_bias"]: - raise NotImplementedError("o_proj_bias is not fully supported") - if orig_cfg_dict["position_embedding_type"] not in ["rope", "yarn"]: - # this one is not supported by MCore - raise ValueError( - f"only rope and yarn are supported, got {orig_cfg_dict['position_embedding_type']}" - ) - - # Handle dtype (new format uses 'dtype', old format uses 'torch_dtype') - # Check 'dtype' first, then fall back to 'torch_dtype' - if "dtype" in orig_cfg_dict and orig_cfg_dict["dtype"] is not None: - mprint(f"DEBUG: dtype found in config: {orig_cfg_dict['dtype']}") - adapted_cfg_dict["torch_dtype"] = orig_cfg_dict["dtype"] - elif "torch_dtype" in orig_cfg_dict and orig_cfg_dict["torch_dtype"] is not None: - mprint(f"DEBUG: torch_dtype found in config: {orig_cfg_dict['torch_dtype']}") - adapted_cfg_dict["torch_dtype"] = orig_cfg_dict["torch_dtype"] - else: - mprint( - f"WARNING: neither dtype nor torch_dtype found in config (or both are None), setting to bfloat16" - ) - adapted_cfg_dict["torch_dtype"] = "bfloat16" - - # TODO: check how config keys such as position_embedding_type are handled (since they're not passed to the constructor) - adapted_cfg_dict["heterogeneous_layers_config_path"] = None - adapted_cfg_dict["block_configs"] = block_conf["block_configs"] - adapted_cfg_dict["heterogeneous_layers_config_encoded_json"] = json.dumps( - block_conf, ensure_ascii=False - ) - adapted_cfg_dict["transformer_layer_spec"] = heterogeneous_layer_spec_puzzletron - adapted_cfg_dict["vocab_size"] = orig_cfg_dict["vocab_size"] - adapted_cfg_dict["num_layers"] = len(orig_cfg_dict["block_configs"]) - adapted_cfg_dict["hidden_size"] = orig_cfg_dict["hidden_size"] - # adapted_cfg_dict['num_attention_heads'] = cfg["num_attention_heads"] - adapted_cfg_dict["kv_channels"] = adapted_cfg_dict["head_dim"] - adapted_cfg_dict["scale_factor"] = float( - adapted_cfg_dict["rope_scaling"].get("factor", 8.0) - ) - adapted_cfg_dict["rotary_base"] = int(orig_cfg_dict.get("rope_theta", 500_000)) - adapted_cfg_dict["seq_length"] = int(orig_cfg_dict.get("max_position_embeddings", 131072)) - adapted_cfg_dict["init_method_std"] = float(orig_cfg_dict.get("initializer_range", 0.02)) - adapted_cfg_dict["layernorm_epsilon"] = float(orig_cfg_dict.get("rms_norm_eps", 1e-5)) - adapted_cfg_dict["share_embeddings_and_output_weights"] = bool( - orig_cfg_dict.get("tie_word_embeddings", False) - ) - # adapted_cfg_dict["make_vocab_size_divisible_by"] = 128 - - # Preserve HF-specific config fields that don't have NeMo equivalents - # This enables lossless round-trip conversion HF → NeMo → HF - source_hf_config_metadata = {} - - # eos_token_id: HF can have multiple EOS tokens [128001, 128008, 128009] - # but NeMo tokenizer only supports single eos_id (uses the last one) - if "eos_token_id" in orig_cfg_dict: - source_hf_config_metadata["eos_token_id"] = orig_cfg_dict["eos_token_id"] - - # auto_map: HF-specific field for custom model class loading via trust_remote_code - # Not relevant to NeMo but needed for HF model.from_pretrained() to work - if "auto_map" in orig_cfg_dict: - source_hf_config_metadata["auto_map"] = orig_cfg_dict["auto_map"] - - # dtype: HF uses 'dtype' field, NeMo uses 'torch_dtype', preserve both - if "dtype" in orig_cfg_dict: - source_hf_config_metadata["dtype"] = orig_cfg_dict["dtype"] - - # Store as direct config attribute (will be serialized by NeMo automatically) - adapted_cfg_dict["source_hf_config_metadata"] = ( - source_hf_config_metadata if source_hf_config_metadata else None - ) - - return adapted_cfg_dict - - -class PuzzletronLlamaNemotronModel(GPTModel): - """Llama-Nemotron model implementation based on the GPT model architecture. - - This class provides a high-level interface for Llama-Nemotron models, - implementing the specific architecture and settings needed for Llama-Nemotron models. - """ - - def __init__( - self, - config: Annotated[ - Optional[PuzzletronNemotronModelConfig] | type[PuzzletronNemotronModelConfig], - Config[PuzzletronNemotronModelConfig], - ] = None, - optim: Optional[OptimizerModule] = None, - tokenizer: Optional["TokenizerSpec"] = None, - model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, - ): - super().__init__( - config or PuzzletronNemotronModelConfig(), - optim=optim, - tokenizer=tokenizer, - model_transform=model_transform, - ) - - -def instantiate_nemo_config_from_adapted_dict( - adapted_cfg_dict: dict, - generation_config: Optional["GenerationConfig"] = None, -) -> PuzzletronNemotronModelConfig: - """ - Instantiate PuzzletronNemotronModelConfig from adapted config dict. - - This function is shared by the importer and tests to ensure consistency. - - Args: - adapted_cfg_dict: Dict created by create_adapted_config_dict_from_puzzletron_config - generation_config: Optional generation config to attach - - Returns: - PuzzletronNemotronModelConfig instance - """ - - # Helper function for vocab size divisibility - def make_vocab_size_divisible_by(vocab_size: int) -> int: - base = 128 - while vocab_size % base != 0: - base //= 2 - return base - - # Keys used for PuzzletronNemotronModelConfig instantiation - INSTANTIATION_KEYS = { - "heterogeneous_layers_config_encoded_json", - "transformer_layer_spec", - "num_layers", - "hidden_size", - "num_attention_heads", - "kv_channels", - "scale_factor", - "init_method_std", - "layernorm_epsilon", - "seq_length", - "rotary_base", - "vocab_size", - "share_embeddings_and_output_weights", - "source_hf_config_metadata", - } - - # Keys that are metadata or derived (not directly passed to constructor) - metadata_keys = set(adapted_cfg_dict.keys()) - INSTANTIATION_KEYS - - mprint(f"DEBUG: Keys used for instantiation: {sorted(INSTANTIATION_KEYS)}") - mprint(f"DEBUG: Metadata keys (not used for direct instantiation): {sorted(metadata_keys)}") - for key in sorted(metadata_keys): - value = adapted_cfg_dict[key] - if isinstance(value, (list, dict)): - mprint(f" - {key}: {type(value).__name__} with {len(value)} items") - elif callable(value): - mprint(f" - {key}: {value.__name__ if hasattr(value, '__name__') else 'callable'}") - else: - mprint(f" - {key}: {value}") - - model_dtype = dtype_from_dict(adapted_cfg_dict) - - # Determine the unique activation_func from all blocks - # MCore validates the global activation_func, so we need to set it to match all blocks - heterogeneous_config = json.loads(adapted_cfg_dict["heterogeneous_layers_config_encoded_json"]) - block_list = heterogeneous_config.get("block_configs", []) - - # Assert that block_configs exists and is not empty - assert block_list, ( - "No block_configs found in heterogeneous_layers_config_encoded_json. " - "The JSON structure must contain a 'block_configs' list with at least one block." - ) - - activation_funcs = [] - - for i, block in enumerate(block_list): - # Extract hidden_act from MLP config (if present) - if "mlp" in block and "hidden_act" in block["mlp"]: - hidden_act_str = block["mlp"]["hidden_act"] - - # Track None/null values (used for no-op blocks) - if hidden_act_str is None: - activation_funcs.append(None) - continue - - # For now, only support silu and gelu activations - # See: https://github.com/NVIDIA/Megatron-LM/blob/268fda08592528b7bc1a21aadaed259980ca8efb/megatron/core/transformer/transformer_config.py#L1043-L1048 - if hidden_act_str == "silu": - activation_funcs.append(F.silu) - elif hidden_act_str == "gelu": - activation_funcs.append(F.gelu) - else: - raise NotImplementedError( - f"Unsupported activation function: '{hidden_act_str}' in block {i}. " - f"Only 'silu', 'gelu', and None/null are currently supported. " - f"MCore's bias_activation_fusion only validates these activation functions." - ) - # If no hidden_act key or no MLP, we treat it as None - else: - activation_funcs.append(None) - - # Separate None and not-None activations - not_none_activations = [f for f in activation_funcs if f is not None] - - # Check that all not-None activation functions are the same - unique_not_none = {id(f) for f in not_none_activations} - - if len(unique_not_none) == 0: - # No activation functions found (all blocks are no-op or have None) - # Default to F.silu to pass MCore validation - global_activation_func = F.silu - mprint( - "WARNING: No not-None activation functions found in blocks, defaulting global activation_func to F.silu" - ) - elif len(unique_not_none) == 1: - # All not-None blocks use the same activation function (safe) - global_activation_func = not_none_activations[0] - func_name = ( - global_activation_func.__name__ - if hasattr(global_activation_func, "__name__") - else str(global_activation_func) - ) - none_count = activation_funcs.count(None) - total_count = len(activation_funcs) - mprint( - f"INFO: All {total_count - none_count} not-None blocks use the same activation function: {func_name} ({none_count} None/no-op blocks)" - ) - else: - # Multiple different not-None activation functions found (currently not supported/tested) - func_names = [f.__name__ if hasattr(f, "__name__") else "None" for f in activation_funcs] - unique_func_names = set(f.__name__ for f in not_none_activations) - assert False, ( - f"Puzzletron blocks must all use the same activation function (None allowed for no-op blocks). " - f"Found {len(unique_not_none)} different not-None activation functions across blocks: {unique_func_names}. " - f"Block activations: {func_names}. " - f"MCore's validation only checks the global activation_func, which would not match heterogeneous activations. " - f"Either make all blocks use the same activation, or update MCore to support per-block validation." - ) - - return PuzzletronNemotronModelConfig( - heterogeneous_layers_config_encoded_json=adapted_cfg_dict[ - "heterogeneous_layers_config_encoded_json" - ], - heterogeneous_layers_config_path=None, # We directly load the block config as json - transformer_layer_spec=adapted_cfg_dict["transformer_layer_spec"], - activation_func=global_activation_func, # Set to match all blocks - num_layers=adapted_cfg_dict["num_layers"], - hidden_size=adapted_cfg_dict["hidden_size"], - num_attention_heads=adapted_cfg_dict["num_attention_heads"], - kv_channels=adapted_cfg_dict["kv_channels"], - scale_factor=adapted_cfg_dict["scale_factor"], - init_method_std=adapted_cfg_dict["init_method_std"], - layernorm_epsilon=adapted_cfg_dict["layernorm_epsilon"], - seq_length=adapted_cfg_dict["seq_length"], - rotary_base=adapted_cfg_dict["rotary_base"], - make_vocab_size_divisible_by=make_vocab_size_divisible_by(adapted_cfg_dict["vocab_size"]), - vocab_size=adapted_cfg_dict["vocab_size"], - share_embeddings_and_output_weights=adapted_cfg_dict["share_embeddings_and_output_weights"], - # HF-specific metadata for lossless round-trip conversion - source_hf_config_metadata=adapted_cfg_dict.get("source_hf_config_metadata"), - fp16=(model_dtype == torch.float16), - bf16=(model_dtype == torch.bfloat16), - params_dtype=model_dtype, - generation_config=generation_config, - ) - - -@io.model_importer(PuzzletronLlamaNemotronModel, "hf") -class PuzzletronHFLlamaNemotronImporter( - io.ModelConnector["LlamaForCausalLM", PuzzletronLlamaNemotronModel] -): - """Importer for converting Hugging Face Llama-Nemotron models to NeMo format. - - This class handles the conversion of Hugging Face's LlamaForCausalLM models - to NeMo's PuzzletronLlamaNemotronModel format, including weight mapping and configuration translation. - """ - - # Base mapping using standard LLaMA weight names - # Layernorm wildcards are replaced with per-layer mappings in convert_state() - # TODO: MoE and Mamba layer conversions have not been tested yet - default_mapping = { - "model.embed_tokens.weight": "embedding.word_embeddings.weight", - "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", - "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", - "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", - "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", - "model.norm.weight": "decoder.final_layernorm.weight", - "lm_head.weight": "output_layer.weight", - } - - def init(self) -> PuzzletronLlamaNemotronModel: - """Initialize a NeMo LlamaModel instance. - - Returns: - LlamaModel: Initialized NeMo Llama model with the appropriate configuration - and tokenizer. - """ - config = self.config - mprint(f"DEBUG: NeMo config dtype settings:") - mprint(f" - config.bf16: {config.bf16}") - mprint(f" - config.fp16: {config.fp16}") - mprint(f" - config.params_dtype: {config.params_dtype}") - return PuzzletronLlamaNemotronModel(config, tokenizer=self.tokenizer) - - def apply(self, output_path: Path) -> Path: - """Apply the conversion from HF to NeMo format. - - Args: - output_path: Path where the converted model will be saved - - Returns: - Path: Path to the saved NeMo model - """ - from transformers import AutoModelForCausalLM - - logging.info(f"Load Puzzletron HF model {str(self)}") - source = AutoModelForCausalLM.from_pretrained( - str(self), trust_remote_code=True, torch_dtype="auto" - ) - logging.info("Initialize NeMo Puzzletron Llama Nemotron model") - target = self.init() - trainer = self.nemo_setup(target) - self.convert_state(source, target) - self.nemo_save(output_path, trainer) - - mprint( - f"Converted Llama-Nemotron model to Nemo, model saved to {output_path} in {source.dtype}." - ) - - teardown(trainer, target) - del trainer, target - - return output_path - - def convert_state(self, source: Any, target: Any) -> Any: - """Convert state dict from HF format to NeMo format. - - Maps the weights from the HF model to the NeMo model according to - the appropriate mapping scheme. - - Args: - source: Source HF model - target: Target NeMo model - - Returns: - The result of applying the transforms - """ - mapping = self.default_mapping.copy() - - if target.config.normalization == "LayerNorm": - mapping["model.norm.bias"] = "decoder.final_layernorm.bias" - if getattr(source.config, "tie_word_embeddings", False): - del mapping["lm_head.weight"] - - # Puzzletron models must have block_configs for heterogeneous layer support - assert hasattr(source.config, "block_configs"), "Puzzletron models must have block_configs" - - # Build per-layer specific mappings for heterogeneous support - attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs = ( - _build_puzzletron_mappings_and_transforms(source.config) - ) - - # Remove layernorm wildcards from default_mapping - these will be replaced with - # specific per-layer mappings based on each layer's architecture. - for pattern in [ - "model.layers.*.input_layernorm.weight", - "model.layers.*.post_attention_layernorm.weight", - ]: - if pattern in mapping: - del mapping[pattern] - - # Add all layer-specific mappings - mapping.update(**attn_mapping) - mapping.update(**ffn_mapping) - mapping.update(**mamba_mapping) - mapping.update(**moe_mapping) - - # Create transforms from specification - transforms = [] - - # Helper to create merge_qkv closure with proper layer index capture - def make_merge_qkv_fn(layer_idx): - def merge_qkv_fn(ctx, q, k, v): - return merge_qkv_for_puzzletron(ctx, q, k, v, idx=layer_idx) - - return merge_qkv_fn - - for spec in transform_specs: - if spec["transform_function"] == "merge_qkv_for_puzzletron": - # Fixed: proper closure to avoid variable capture issues - layer_idx = spec["kwargs"]["idx"] - transforms.append( - io.state_transform( - source_key=spec["source_key"], - target_key=spec["target_key"], - fn=make_merge_qkv_fn(layer_idx), - ) - ) - elif spec["transform_function"] == "merge_fc1_for_moe": - transforms.append( - io.state_transform( - source_key=spec["source_key"], - target_key=spec["target_key"], - fn=TransformFns.merge_fc1, - ) - ) - - # Add standard FC1 merge transform - transforms.append( - io.state_transform( - source_key=( - "model.layers.*.mlp.gate_proj.weight", - "model.layers.*.mlp.up_proj.weight", - ), - target_key="decoder.layers.*.mlp.linear_fc1.weight", - fn=TransformFns.merge_fc1, - ) - ) - return io.apply_transforms(source, target, mapping=mapping, transforms=transforms) - - @property - def tokenizer(self) -> "AutoTokenizer": - """Get the tokenizer for the HF model. - - Returns: - AutoTokenizer: Tokenizer instance initialized from the HF model's tokenizer - """ - from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer - - return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)), trust_remote_code=True) - - @property - def config(self) -> PuzzletronNemotronModelConfig: - """Create a NeMo LlamaNemotronConfig from the HF model config. - - Translates the HF configuration parameters to the equivalent NeMo - configuration. - - Returns: - PuzzletronNemotronModelConfig: Puzzletron NeMo configuration for Llama models - """ - from transformers import AutoConfig, GenerationConfig - - source = AutoConfig.from_pretrained(str(self), trust_remote_code=True) - - # Validate that this is a proper Puzzletron-Nemotron checkpoint - assert getattr(source, "rope_scaling", None), ( - "Llama-Nemotron model should have rope scaling" - ) - assert getattr(source, "block_configs", None) is not None, ( - "Puzzletron-Nemotron model should be heterogeneous and have block configs" - ) - - adapted_cfg_dict = ( - PuzzletronNemotronModelConfig.create_adapted_config_dict_from_puzzletron_config(source) - ) - - try: - generation_config = GenerationConfig.from_pretrained(str(self)) - except Exception: - generation_config = None - - output = instantiate_nemo_config_from_adapted_dict( - adapted_cfg_dict, generation_config=generation_config - ) - return output - - -@io.model_exporter(PuzzletronLlamaNemotronModel, "hf") -class PuzzletronHFLlamaNemotronExporter( - io.ModelConnector[PuzzletronLlamaNemotronModel, "LlamaForCausalLM"] -): - """Exporter for converting NeMo Puzzletron Llama-Nemotron models to Hugging Face format. - - This class handles the conversion of NeMo's PuzzletronLlamaNemotronModel to Hugging Face's - LlamaForCausalLM format, including weight mapping and configuration translation. - It supports heterogeneous model architectures with Puzzletron-specific configurations. - - The exporter performs the following key operations: - 1. Initializes a Hugging Face model with appropriate configuration - 2. Maps weights from NeMo format to Hugging Face format - 3. Handles special cases for heterogeneous architectures with Mamba, MoE, and other custom layers - 4. Saves the converted model and tokenizer to the specified output path - - Attributes: - tokenizer: The tokenizer associated with the NeMo model - config: The configuration for the Hugging Face model - - Methods: - init: Initialize a Hugging Face model instance - apply: Convert and save the model to Hugging Face format - convert_state: Convert model weights from NeMo to Hugging Face format - """ - - # Base mapping for NeMo -> HF conversion (reversed from importer) - # Layernorm wildcards are replaced with per-layer mappings in convert_state() - default_mapping = { - "embedding.word_embeddings.weight": "model.embed_tokens.weight", - "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", - "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", - "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", - "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", - "decoder.final_layernorm.weight": "model.norm.weight", - "output_layer.weight": "lm_head.weight", - } - - @property - def config(self) -> "DeciLMConfig": - """Create a HF DeciLMConfig from the NeMo model config. - - This method constructs a DeciLMConfig for Puzzletron models by parsing the - heterogeneous_layers_config_encoded_json from the NeMo config and mapping - the fields to the HF DeciLM format. - - Returns: - DeciLMConfig: HF configuration for Puzzletron DeciLM models - """ - # Load the NeMo config - source_config = io.load_context(str(self), subpath="model.config") - - # Get preserved HF config metadata (stored as direct attribute) - # This enables lossless round-trip conversion HF → NeMo → HF - source_hf_config_metadata = getattr(source_config, "source_hf_config_metadata", None) or {} - - # Get EOS token ID(s) - prefer preserved value from source HF config metadata - # (HF supports multiple EOS tokens, NeMo tokenizer only has single eos_id) - eos_token_id = source_hf_config_metadata.get("eos_token_id", self.tokenizer.eos_id) - - # Use the shared conversion function - return convert_nemo_config_to_hf_decilm_config( - nemo_config=source_config, - vocab_size=self.tokenizer.vocab_size, - eos_token_id=eos_token_id, - bos_token_id=self.tokenizer.bos_id, - pad_token_id=getattr(self.tokenizer, "pad_id", None), - ) - - def init(self, dtype=torch.bfloat16, from_config=False, model_name=None) -> "LlamaForCausalLM": - """Initialize a Hugging Face LlamaForCausalLM model instance. - - This method creates a new Hugging Face model instance with the appropriate configuration - and data type. Puzzletron models always use from_config=True and create a DeciLMForCausalLM. - - Args: - dtype (torch.dtype, optional): Data type for model parameters. Defaults to torch.bfloat16. - from_config (bool, optional): Whether to initialize from config or load from pretrained. - For Puzzletron models, this should always be True. Defaults to False. - model_name (str, optional): Name of the pretrained model to load. Not used for Puzzletron - models since we generate the config dynamically. Defaults to None. - - Returns: - DeciLMForCausalLM: Initialized Hugging Face DeciLM model instance - - Raises: - ValueError: If model_name is provided (not supported for Puzzletron models) - """ - from transformers.modeling_utils import no_init_weights - - from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( - DeciLMForCausalLM, - ) - - with no_init_weights(): - if from_config: - # Puzzletron models: create DeciLMForCausalLM from self.config property - model = DeciLMForCausalLM(self.config) - model = model.to(dtype=dtype) - return model - else: - # Puzzletron models don't support loading from pretrained HF model cards - raise ValueError( - "Puzzletron models do not have official HF model cards. " - "Use from_config=True to create models from NeMo config." - ) - - def apply(self, output_path: Path, target_model_name=None) -> Path: - """Convert and save a NeMo Puzzletron Llama-Nemotron model to Hugging Face format. - - This method performs the complete conversion process: - 1. Loads the NeMo model checkpoint - 2. Creates the Hugging Face model from config - 3. Converts and transfers the weights - 4. Saves the converted model and tokenizer - - Args: - output_path (Path): Directory path where the converted model will be saved - target_model_name (str, optional): Not used for Puzzletron models. Kept for API compatibility. - - Returns: - Path: Path to the saved Hugging Face model directory - """ - logging.info("Loading Puzzletron Llama-Nemotron NeMo checkpoint..") - source, _ = self.nemo_load(str(self)) - - # Puzzletron models always use from_config=True to generate DeciLMConfig dynamically - target = self.init( - torch_dtype_from_mcore_config(source.config), - from_config=True, - model_name=None, - ) - target = self.convert_state(source, target) - - target = target.cpu() - target.save_pretrained(output_path) - self.tokenizer.tokenizer.save_pretrained(output_path) - - # Copy custom Python files needed for Puzzletron models - from modelopt.torch.puzzletron.export.MCore.llama_nemotron_utils import ( - copy_puzzletron_python_files_to_decilm_checkpoint, - embed_chat_template_in_tokenizer_config, - ) - - copy_puzzletron_python_files_to_decilm_checkpoint(str(output_path)) - - # Fix tokenizer: embed chat_template from .jinja file into tokenizer_config.json - # NeMo's HF → NeMo import extracts chat_template to .jinja but doesn't preserve - # it in tokenizer_config.json. We restore it here for accuracy parity. - embed_chat_template_in_tokenizer_config(str(self), str(output_path)) - - return output_path - - def convert_state(self, source: Any, target: Any) -> Any: - """Convert state dict from NeMo format to HF format. - - Maps the weights from the NeMo model to the HF model according to - the appropriate mapping scheme for Puzzletron models. - - This method follows the same pattern as the importer but with reversed mappings: - 1. Start with default mapping - 2. Remove layernorm wildcards (will be replaced with per-layer mappings) - 3. Build per-layer specific mappings using helper function and reverse them - 4. Create transforms for weight conversions - - Args: - source: Source NeMo model - target: Target HF model - - Returns: - The target model with weights transferred from source - """ - mapping = self.default_mapping.copy() - - # Handle LayerNorm bias if present - if source.config.normalization == "LayerNorm": - mapping["decoder.final_layernorm.bias"] = "model.norm.bias" - - # Handle tied embeddings - if getattr(source.config, "share_embeddings_and_output_weights", False): - # Remove output_layer mapping if embeddings are tied - if "output_layer.weight" in mapping: - del mapping["output_layer.weight"] - - # Build per-layer specific mappings for heterogeneous support - attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs = ( - _build_puzzletron_mappings_and_transforms(source.config) - ) - - # Remove layernorm wildcards from default_mapping - these will be replaced with - # specific per-layer mappings based on each layer's architecture. - for pattern in [ - "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", - "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", - ]: - if pattern in mapping: - del mapping[pattern] - - # For exporter: reverse all mappings (HF -> NeMo becomes NeMo -> HF) - attn_mapping = {v: k for k, v in attn_mapping.items()} - ffn_mapping = {v: k for k, v in ffn_mapping.items()} - mamba_mapping = {v: k for k, v in mamba_mapping.items()} - moe_mapping = {v: k for k, v in moe_mapping.items()} - - # Add all layer-specific mappings - mapping.update(**attn_mapping) - mapping.update(**ffn_mapping) - mapping.update(**mamba_mapping) - mapping.update(**moe_mapping) - - # Create transforms from specifications (reversed for exporter) - transforms = [] - - # Helper to create split_qkv closure with proper layer index capture - def make_split_qkv_fn(layer_idx): - def split_qkv_fn(ctx, qkv): - return split_qkv_for_puzzletron(ctx, qkv, idx=layer_idx) - - return split_qkv_fn - - for spec in transform_specs: - if spec["transform_function"] == "merge_qkv_for_puzzletron": - # For exporter: split QKV (NeMo -> HF) - layer_idx = spec["kwargs"]["idx"] - transforms.append( - io.state_transform( - source_key=spec["target_key"], # NeMo key - target_key=spec["source_key"], # HF key - fn=make_split_qkv_fn(layer_idx), - ) - ) - elif spec["transform_function"] == "merge_fc1_for_moe": - # For exporter: split FC1 for MoE (NeMo -> HF) - transforms.append( - io.state_transform( - source_key=spec["target_key"], # NeMo key - target_key=spec["source_key"], # HF key - fn=TransformFns.split_fc1, - ) - ) - - # Add standard transforms for FC1 splitting and padding pruning - transforms.extend( - [ - io.state_transform( - source_key="decoder.layers.*.mlp.linear_fc1.weight", - target_key=( - "model.layers.*.mlp.gate_proj.weight", - "model.layers.*.mlp.up_proj.weight", - ), - fn=TransformFns.split_fc1, - ), - io.state_transform( - source_key="embedding.word_embeddings.weight", - target_key="model.embed_tokens.weight", - fn=TransformFns.prune_padding, - ), - io.state_transform( - source_key="output_layer.weight", - target_key="lm_head.weight", - fn=TransformFns.prune_padding, - ), - ] - ) - - return io.apply_transforms( - source, - target, - mapping=mapping, - transforms=transforms, - ) - - @property - def tokenizer(self) -> "TokenizerSpec": - """Get the tokenizer from the NeMo model. - - Returns: - TokenizerSpec: Tokenizer from the NeMo model - """ - return io.load_context(str(self), subpath="model").tokenizer - - -__all__ = [ - "PuzzletronLlamaNemotronModel", -] diff --git a/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py b/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py deleted file mode 100644 index 8d01ec9537..0000000000 --- a/modelopt/torch/puzzletron/export/MCore/llama_nemotron_utils.py +++ /dev/null @@ -1,729 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from dataclasses import asdict -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from megatron.core.transformer.spec_utils import ModuleSpec -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import ( - AutoTokenizer as NemoAutoTokenizer, -) -from nemo.collections.llm.gpt.model.base import GPTModel -from nemo.collections.llm.gpt.model.llama_nemotron import ( - HFLlamaNemotronImporter, - PuzzletronNemotronModelConfig, -) -from nemo.lightning import io, teardown -from nemo.lightning.io.state import TransformFns -from nemo.lightning.pytorch.utils import dtype_from_str -from nemo.utils.import_utils import safe_import -from transformers import AutoModelForCausalLM, AutoTokenizer - -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( - PuzzletronAttentionConfig, - PuzzletronHeterogeneousTransformerConfig, - PuzzletronMLPConfig, - get_gpt_heterogeneous_layer_spec_puzzletron, -) - - -def convert_attention_config_from_cfg_object(attention_config, num_attention_heads, head_dim): - for unsupported_key in [ - "llama4", - "num_sink_tokens", - "sparsify", - "unshifted_sink", - "use_prefill_window_in_sink_attention", - ]: - if hasattr(attention_config, unsupported_key) and getattr( - attention_config, unsupported_key - ) not in [ - None, - False, - ]: - # - # if attention_config.get(unsupported_key, None) not in [None, False]: - raise NotImplementedError(f"{unsupported_key} is not supported") - window_size = attention_config.window_size if hasattr(attention_config, "window_size") else None - if window_size is not None: - window_size = (window_size, 0) - is_mamba = attention_config.mamba if hasattr(attention_config, "mamba") else False - n_heads_in_group = ( - attention_config.n_heads_in_group if hasattr(attention_config, "n_heads_in_group") else 1 - ) - if n_heads_in_group is None: - n_heads_in_group = 1 - return asdict( - PuzzletronAttentionConfig( - no_op=attention_config.no_op if hasattr(attention_config, "no_op") else False, - replace_with_linear=( - attention_config.replace_with_linear - if hasattr(attention_config, "replace_with_linear") - else False - ), - num_attention_heads=num_attention_heads, - num_query_groups=num_attention_heads // n_heads_in_group, - kv_channels=head_dim, - window_size=window_size, - multi_latent_attention=False, - is_mamba=is_mamba, - mamba_state_dim=( - attention_config.mamba.state_dim - if is_mamba and hasattr(attention_config.mamba, "state_dim") - else 128 - ), - mamba_head_dim=( - attention_config.mamba.head_dim - if is_mamba and hasattr(attention_config.mamba, "head_dim") - else 64 - ), - mamba_num_groups=( - attention_config.mamba.num_groups - if is_mamba and hasattr(attention_config.mamba, "num_groups") - else 8 - ), - mamba_num_heads=( - attention_config.mamba.num_heads - if is_mamba and hasattr(attention_config.mamba, "num_heads") - else None - ), - ) - ) - - -def convert_mlp_config_from_cfg_object(mlp_config, parallel_blocks): - """Convert MLP config from HF format to NeMo format. - - Args: - mlp_config: The MLP configuration object from HF - parallel_blocks: Parallel blocks configuration (not currently supported) - """ - if parallel_blocks is not None: - raise NotImplementedError("parallel_blocks is not supported") - if not hasattr(mlp_config, "gated") or mlp_config.gated is False: - raise NotImplementedError("notgated MLP is not supported") - - # Validate this block's activation function - if not hasattr(mlp_config, "hidden_act"): - raise ValueError(f"MLP config must have hidden_act attribute") - # if mlp_config.hidden_act != block_hidden_act: - # raise ValueError(f"MLP config hidden_act mismatch: config has {mlp_config.hidden_act}, expected {block_hidden_act}") - - if hasattr(mlp_config, "sparsify") and mlp_config.sparsify is not None: - raise NotImplementedError("sparsify is not supported") - is_moe = hasattr(mlp_config, "moe") and mlp_config.moe is not None - # Note: hidden_act is validated above but not stored in PuzzletronMLPConfig - # It will be used at the call site for the NeMo model config - return asdict( - PuzzletronMLPConfig( - no_op=mlp_config.no_op if hasattr(mlp_config, "no_op") else False, - replace_with_linear=mlp_config.replace_with_linear - if hasattr(mlp_config, "replace_with_linear") - else False, - ffn_hidden_size=mlp_config.intermediate_size - if hasattr(mlp_config, "intermediate_size") - else None, - num_moe_experts=( - mlp_config.moe.num_local_experts - if is_moe and hasattr(mlp_config.moe, "num_local_experts") - else None - ), - moe_shared_expert_intermediate_size=( - mlp_config.moe.shared_expert_intermediate_dim - if is_moe and hasattr(mlp_config.moe, "shared_expert_intermediate_dim") - else None - ), - moe_ffn_hidden_size=( - mlp_config.moe.expert_intermediate_dim - if is_moe and hasattr(mlp_config.moe, "expert_intermediate_dim") - else None - ), - moe_router_topk=( - mlp_config.moe.num_experts_per_tok - if is_moe and hasattr(mlp_config.moe, "num_experts_per_tok") - else 2 - ), - ) - ) - - -def convert_nemo_config_to_hf_decilm_config( - nemo_config: "PuzzletronNemotronModelConfig", - vocab_size: int, - eos_token_id: Union[int, List[int], None] = None, - bos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, -) -> "DeciLMConfig": - """Convert a NeMo PuzzletronNemotronModelConfig to HF DeciLMConfig. - - This function extracts the conversion logic from the exporter so it can be - used in unit tests without requiring file I/O. - - Args: - nemo_config: The NeMo config to convert - vocab_size: Vocabulary size for the HF config - eos_token_id: EOS token ID(s). Can be int or list of ints. - bos_token_id: BOS token ID - pad_token_id: PAD token ID - - Returns: - DeciLMConfig: The equivalent HF config - """ - - # Get preserved HF config metadata (stored as direct attribute) - # This enables lossless round-trip conversion HF → NeMo → HF - source_hf_config_metadata = getattr(nemo_config, "source_hf_config_metadata", None) or {} - - # Parse the heterogeneous layers config from JSON - block_configs = [] - - if ( - hasattr(nemo_config, "heterogeneous_layers_config_encoded_json") - and nemo_config.heterogeneous_layers_config_encoded_json - ): - try: - heterogeneous_config = json.loads(nemo_config.heterogeneous_layers_config_encoded_json) - raw_block_configs = heterogeneous_config.get("block_configs", []) - - for i, raw_block_config in enumerate(raw_block_configs): - attn_block = raw_block_config.get("attention", {}) - mlp_block = raw_block_config.get("mlp", {}) - - # Configure attention - attention_config = { - "no_op": attn_block.get("no_op", False), - "replace_with_linear": attn_block.get("replace_with_linear", False), - "sparsify": attn_block.get("sparsify", None), - "n_heads_in_group": attn_block.get( - "num_attention_heads", nemo_config.num_attention_heads - ) - // attn_block.get("num_query_groups", nemo_config.num_query_groups), - "window_length": attn_block.get("window_size", None), - "num_sink_tokens": attn_block.get("num_sink_tokens", None), - "use_prefill_window_in_sink_attention": attn_block.get( - "use_prefill_window_in_sink_attention", False - ), - "unshifted_sink": attn_block.get("unshifted_sink", False), - } - - # Handle Mamba: convert from NeMo flat structure to HF nested structure - if attn_block.get("is_mamba", False): - attention_config["mamba"] = { - "state_dim": attn_block.get("mamba_state_dim", 128), - "num_heads": attn_block.get( - "mamba_num_heads", nemo_config.num_attention_heads - ), - "head_dim": attn_block.get("mamba_head_dim", 64), - "num_groups": attn_block.get("mamba_num_groups", 8), - } - else: - attention_config["mamba"] = None - - # Handle Llama4: pass through as dict if present - attention_config["llama4"] = attn_block.get("llama4", None) - - # Configure FFN - ffn_config = { - "no_op": mlp_block.get("no_op", False), - "replace_with_linear": mlp_block.get("replace_with_linear", False), - "sparsify": mlp_block.get("sparsify", None), - "intermediate_size": mlp_block.get( - "ffn_hidden_size", nemo_config.ffn_hidden_size - ), - "gated": True, # Puzzletron uses gated activations - # Use the activation function name extracted from this block's config - "hidden_act": mlp_block.get("hidden_act", None), - } - - # Handle MoE: convert from NeMo flat structure to HF nested structure - num_moe_experts = mlp_block.get("num_moe_experts", None) - if num_moe_experts is not None: - ffn_config["moe"] = { - "num_local_experts": num_moe_experts, - "num_experts_per_tok": mlp_block.get("moe_router_topk", 1), - "expert_intermediate_dim": mlp_block.get("moe_ffn_hidden_size", 8192), - "shared_expert_intermediate_dim": mlp_block.get( - "moe_shared_expert_intermediate_size", 8192 - ), - } - else: - ffn_config["moe"] = None - - block_configs.append({"attention": attention_config, "ffn": ffn_config}) - except (json.JSONDecodeError, KeyError) as e: - raise ValueError(f"Could not parse heterogeneous config JSON: {e}") - else: - raise ValueError("No block configs found in source configuration") - - # Create rope scaling config - rope_scaling = { - "factor": nemo_config.scale_factor, - "low_freq_factor": getattr(nemo_config, "low_freq_factor", 1.0), - "high_freq_factor": getattr(nemo_config, "high_freq_factor", 4.0), - "original_max_position_embeddings": getattr(nemo_config, "old_context_len", 8192), - "rope_type": "llama3", - } - - # Get EOS token ID(s) - prefer preserved value from source HF config metadata or provided value - if eos_token_id is None: - eos_token_id = source_hf_config_metadata.get("eos_token_id", None) - - # Create DeciLM config - hf_config = DeciLMConfig( - block_configs=block_configs, - hidden_size=nemo_config.hidden_size, - max_position_embeddings=nemo_config.seq_length, - num_attention_heads=nemo_config.num_attention_heads, - num_hidden_layers=nemo_config.num_layers, - tie_word_embeddings=nemo_config.share_embeddings_and_output_weights, - vocab_size=vocab_size, - rms_norm_eps=nemo_config.layernorm_epsilon, - attention_bias=getattr(nemo_config, "attention_bias", False), - o_proj_bias=getattr( - nemo_config, "o_proj_bias", getattr(nemo_config, "attention_bias", False) - ), - rope_theta=nemo_config.rotary_base, - rope_scaling=rope_scaling, - position_embedding_type="rope", - architectures=["DeciLMForCausalLM"], - model_type="nemotron-nas", - eos_token_id=eos_token_id, - bos_token_id=bos_token_id, - pad_token_id=pad_token_id, - head_dim=nemo_config.kv_channels, - # Restore auto_map from preserved metadata (needed for trust_remote_code loading) - auto_map=source_hf_config_metadata.get( - "auto_map", - { - "AutoConfig": "configuration_decilm.DeciLMConfig", - "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM", - }, - ), - # Restore dtype field from preserved metadata - dtype=source_hf_config_metadata.get("dtype", "bfloat16"), - ) - - return hf_config - - -def _config_to_dict(config) -> Dict[str, Any]: - """Convert a config object to a dictionary. - - Args: - config: Either an object with attributes or already a dictionary - - Returns: - Dictionary representation of the config - """ - if isinstance(config, dict): - return config - return vars(config) - - -def _build_puzzletron_mappings_and_transforms( - source_config: PuzzletronHeterogeneousTransformerConfig, -) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str], Dict[str, str], List[Dict[str, Any]]]: - """Build mappings and transform specifications for Puzzletron heterogeneous models. - - Args: - source_config: The Puzzletron heterogeneous transformer configuration - - Returns: - Tuple containing: - - attn_mapping: Attention layer mappings - - ffn_mapping: FFN layer mappings - - mamba_mapping: Mamba layer mappings - - moe_mapping: MoE layer mappings - - transform_specs: List of transform specifications with source_key, target_key, transform_function - """ - attn_mapping = {} - ffn_mapping = {} - mamba_mapping = {} - moe_mapping = {} - transform_specs = [] - - # Determine config type and extract block configs - is_hf_config = hasattr(source_config, "block_configs") - is_nemo_config = ( - hasattr(source_config, "heterogeneous_layers_config_encoded_json") - and source_config.heterogeneous_layers_config_encoded_json - ) - assert not (is_hf_config and is_nemo_config), "Cannot have both HF and NeMo config" - - if is_hf_config: - # HF config case (importer) - block_configs = source_config.block_configs - elif is_nemo_config: - # NeMo config case (exporter) - parse JSON - try: - heterogeneous_config = json.loads( - source_config.heterogeneous_layers_config_encoded_json - ) - block_configs = heterogeneous_config.get("block_configs", []) - except (json.JSONDecodeError, KeyError): - block_configs = [] - else: - block_configs = [] - - # Check if we found any block configs - if not block_configs: - raise ValueError( - "No block configs found in source configuration. " - "Expected either 'block_configs' attribute (HF config) or " - "'heterogeneous_layers_config_encoded_json' attribute (NeMo config) with valid block configs." - ) - - # TODO it is better (more stable) to use target.config.block_configs - for idx, block_config in enumerate(block_configs): - # Convert block config to dictionary - block_dict = _config_to_dict(block_config) - - # Extract attention and FFN configs (handle both HF "ffn" and NeMo "mlp" keys) - attn = block_dict.get("attention") - ffn = block_dict.get("ffn") or block_dict.get("mlp") - - # Convert sub-configs to dictionaries - attn_dict = _config_to_dict(attn) if attn else {} - ffn_dict = _config_to_dict(ffn) if ffn else {} - - # Process attention config - # Handle both HF (mamba) and NeMo (is_mamba) keys - is_mamba = attn_dict.get("mamba") or attn_dict.get("is_mamba") - - if not attn or attn_dict.get("no_op"): - value = None - elif attn_dict.get("replace_with_linear"): - value = f"decoder.layers.{idx}.self_attention.layer_norm_weight" - elif is_mamba is not None: - value = f"decoder.layers.{idx}.self_attention.in_proj.layer_norm_weight" - for mamba_key in [ - "dt_bias", - "A_log", - "D", - "in_proj.weight", - "conv1d.weight", - "conv1d.bias", - "norm.weight", - "out_proj.weight", - ]: - mamba_mapping[f"model.layers.{idx}.self_attn.mamba_mixer.{mamba_key}"] = ( - f"decoder.layers.{idx}.self_attention.{mamba_key}" - ) - else: - value = f"decoder.layers.{idx}.self_attention.linear_qkv.layer_norm_weight" - # Store transform spec for QKV merging - transform_specs.append( - { - "source_key": ( - f"model.layers.{idx}.self_attn.q_proj.weight", - f"model.layers.{idx}.self_attn.k_proj.weight", - f"model.layers.{idx}.self_attn.v_proj.weight", - ), - "target_key": f"decoder.layers.{idx}.self_attention.linear_qkv.weight", - "transform_function": "merge_qkv_for_puzzletron", - "kwargs": {"idx": idx}, - } - ) - - if value is not None: - attn_mapping[f"model.layers.{idx}.input_layernorm.weight"] = value - - # Process FFN config - # Handle both HF (moe, moe.shared_expert_intermediate_dim) and NeMo (num_moe_experts, moe_shared_expert_intermediate_size) keys - moe_config = ffn_dict.get("moe") or ffn_dict.get("num_moe_experts") - shared_expert_size = None - if moe_config: - # Convert moe_config to dict if it's an object (HF case) - moe_dict = ( - _config_to_dict(moe_config) if not isinstance(moe_config, (int, type(None))) else {} - ) - shared_expert_size = moe_dict.get("shared_expert_intermediate_dim") or ffn_dict.get( - "moe_shared_expert_intermediate_size" - ) - - if not ffn or ffn_dict.get("no_op"): - value = None - elif ffn_dict.get("replace_with_linear"): - value = f"decoder.layers.{idx}.mlp.layer_norm_weight" - elif moe_config is not None: - value = f"decoder.layers.{idx}.pre_mlp_layernorm.weight" - moe_mapping[f"model.layers.{idx}.mlp.router.weight"] = ( - f"decoder.layers.{idx}.mlp.router.weight" - ) - # Store transform spec for MoE expert FC1 merging - transform_specs.append( - { - "source_key": ( - f"model.layers.{idx}.mlp.experts.*.gate_proj.weight", - f"model.layers.{idx}.mlp.experts.*.up_proj.weight", - ), - "target_key": f"decoder.layers.{idx}.mlp.experts.local_experts.*.linear_fc1.weight", - "transform_function": "merge_fc1_for_moe", - "kwargs": {}, - } - ) - moe_mapping[f"model.layers.{idx}.mlp.experts.*.down_proj.weight"] = ( - f"decoder.layers.{idx}.mlp.experts.local_experts.*.linear_fc2.weight" - ) - # Check for shared expert - if shared_expert_size not in [None, 0]: - # Store transform spec for MoE shared expert FC1 merging - transform_specs.append( - { - "source_key": ( - f"model.layers.{idx}.mlp.shared_expert.gate_proj.weight", - f"model.layers.{idx}.mlp.shared_expert.up_proj.weight", - ), - "target_key": f"decoder.layers.{idx}.mlp.shared_experts.linear_fc1.weight", - "transform_function": "merge_fc1_for_moe", - "kwargs": {}, - } - ) - moe_mapping[f"model.layers.{idx}.mlp.shared_expert.down_proj.weight"] = ( - f"decoder.layers.{idx}.mlp.shared_experts.linear_fc2.weight" - ) - else: - value = f"decoder.layers.{idx}.mlp.linear_fc1.layer_norm_weight" - - if value is not None: - ffn_mapping[f"model.layers.{idx}.post_attention_layernorm.weight"] = value - - return attn_mapping, ffn_mapping, mamba_mapping, moe_mapping, transform_specs - - -def merge_qkv_for_puzzletron( - ctx: io.state.TransformCTX, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - idx: Optional[int] = None, -): - """ - Merge q, k, v to interleave-concatenated qkv. - - Modified version of nemo.lightning.io.state.TransformFns.merge_qkv for Puzzletron - - idx can be provided to fetch megatron_config for a specific layer - - heads_per_group is derived from the shape of q and k, instead of calculating (head_num // num_query_groups) from config values - - num_query_groups is not fetched from a global config value, but calculated from head_num and heads_per_group - - Example: import HF {q|k|v}_proj to layer linear_qkv - """ - if idx is not None: - megatron_config = ctx.target.decoder.layers[idx].config - else: - megatron_config = ctx.target.config - head_num = megatron_config.num_attention_heads - heads_per_group = ( - q.shape[0] // k.shape[0] - ) # NOTE: This is important to support heterogeneous attention - num_query_groups = head_num // heads_per_group - hidden_size = megatron_config.hidden_size - head_size = megatron_config.kv_channels - old_tensor_shape = q.size() - new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] - new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] - - q = q.view(*new_q_tensor_shape) - k = k.view(*new_kv_tensor_shape) - v = v.view(*new_kv_tensor_shape) - - qkv_weights_l = [] - for i in range(num_query_groups): - qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) - qkv_weights_l.append(k[i : i + 1, :, :]) - qkv_weights_l.append(v[i : i + 1, :, :]) - qkv_weights = torch.cat(qkv_weights_l) - assert qkv_weights.ndim == 3, qkv_weights.shape - assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape - assert qkv_weights.shape[1] == head_size, qkv_weights.shape - assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape - - qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) - - return qkv_weights - - -def split_qkv_for_puzzletron( - ctx: io.state.TransformCTX, qkv: torch.Tensor, idx: Optional[int] = None -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Split interleave-concatenated qkv to separate q, k, v. - - Inverse operation of merge_qkv_for_puzzletron for Puzzletron - - idx can be provided to fetch megatron_config for a specific layer - - heads_per_group is derived from the shape of qkv, instead of calculating from config values - - num_query_groups is not fetched from a global config value, but calculated from head_num and heads_per_group - - Example: export NeMo layer linear_qkv to HF {q|k|v}_proj - """ - if idx is not None: - megatron_config = ctx.source.decoder.layers[idx].config - else: - megatron_config = ctx.source.config - - head_num = megatron_config.num_attention_heads - head_size = megatron_config.kv_channels - # hidden_size = megatron_config.hidden_size - - # Calculate qkv_total_dim from the actual qkv tensor shape - # qkv shape is [head_size * (head_num + 2 * num_query_groups), hidden_size] - qkv_total_dim = qkv.shape[0] // head_size - num_query_groups = (qkv_total_dim - head_num) // 2 - heads_per_group = head_num // num_query_groups - - # Reshape qkv to 3D: [qkv_total_dim, head_size, hidden_size] - qkv = qkv.reshape([qkv_total_dim, head_size, -1]) - - # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size - # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r - actual_hidden_size = qkv.size(-1) - - # Create slice indices for q, k, v - q_slice = torch.cat( - [ - torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) - for i in range(num_query_groups) - ] - ) - k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) - v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) - - q_proj = qkv[q_slice].reshape(-1, actual_hidden_size).cpu() - k_proj = qkv[k_slice].reshape(-1, actual_hidden_size).cpu() - v_proj = qkv[v_slice].reshape(-1, actual_hidden_size).cpu() - - return q_proj, k_proj, v_proj - - -def dtype_from_dict(config_dict): - """ - Extracts torch dtype from a HF config. - Handles both 'torch_dtype' (old format) and 'dtype' (new format). - """ - # Try torch_dtype first (old format), then dtype (new format) - if "torch_dtype" in config_dict: - torch_dtype = config_dict["torch_dtype"] - elif "dtype" in config_dict: - torch_dtype = config_dict["dtype"] - else: - raise ValueError("Expected config dict to have attr `torch_dtype` or `dtype`") - - if isinstance(torch_dtype, torch.dtype): - return torch_dtype - elif isinstance(torch_dtype, str): - return dtype_from_str(torch_dtype) - else: - raise ValueError(f"dtype is not of type str/torch.dtype, got {type(torch_dtype)}") - - -def copy_puzzletron_python_files_to_decilm_checkpoint(output_path: str) -> None: - """Copy custom Python files from puzzle_tools package to output directory. - - Puzzletron models require custom Python files (configuration_decilm.py, - modeling_decilm.py, etc.) to be present in the checkpoint directory for - loading with transformers.AutoModel. - - This function copies all Python files from puzzle_tools/deci_lm_hf_code/ - to ensure the exported checkpoint is fully functional. - - Args: - output_path: Directory where HF model is being saved - """ - import logging - import shutil - from pathlib import Path - - # Get the puzzle_tools/deci_lm_hf_code directory - # Navigate from this file: export/MCore/llama_nemotron_utils.py -> v1/puzzle_tools/deci_lm_hf_code/ - package_dir = Path(__file__).parent.parent.parent / "puzzle_tools" / "deci_lm_hf_code" - - if not package_dir.exists(): - logging.warning( - f"Custom files directory not found: {package_dir}. " - f"Exported checkpoint may not be loadable without these files." - ) - return - - # Copy all Python files from the package - output_dir = Path(output_path) - copied_files = [] - for src_file in package_dir.glob("*.py"): - dest_file = output_dir / src_file.name - shutil.copy2(src_file, dest_file) - copied_files.append(src_file.name) - - logging.info(f"Copied {len(copied_files)} custom Python files to {output_path}") - logging.debug(f"Custom files copied: {', '.join(sorted(copied_files)[:5])}...") # Show first 5 - - -def embed_chat_template_in_tokenizer_config(nemo_checkpoint_path: str, output_path: str) -> None: - """Embed chat_template from .jinja file into tokenizer_config.json. - - NeMo's HF → NeMo import extracts chat_template to a separate .jinja file - but doesn't preserve it in tokenizer_config.json. This causes accuracy drops - in evaluation. This function restores it by: - 1. Reading chat_template.jinja from the NeMo checkpoint - 2. Embedding it into the exported tokenizer_config.json - - Args: - nemo_checkpoint_path: Path to the NeMo checkpoint (.nemo file/directory) - output_path: Directory where HF model is being saved - """ - import logging - from pathlib import Path - - # Path to NeMo checkpoint tokenizer files - nemo_checkpoint = Path(nemo_checkpoint_path) - nemo_chat_template_jinja = ( - nemo_checkpoint / "context" / "nemo_tokenizer" / "chat_template.jinja" - ) - - # Path to exported tokenizer config - output_dir = Path(output_path) - output_tokenizer_config = output_dir / "tokenizer_config.json" - - # Check if both files exist - if not nemo_chat_template_jinja.exists(): - logging.debug( - f"No chat_template.jinja found in NeMo checkpoint at {nemo_chat_template_jinja}" - ) - return - - if not output_tokenizer_config.exists(): - logging.warning(f"tokenizer_config.json not found at {output_tokenizer_config}") - return - - # Read chat_template from .jinja file - chat_template_content = nemo_chat_template_jinja.read_text() - - # Load tokenizer_config.json - with open(output_tokenizer_config, "r") as f: - tokenizer_config = json.load(f) - - # Check if chat_template is already embedded (shouldn't be, but be safe) - if "chat_template" in tokenizer_config: - logging.debug("chat_template already embedded in tokenizer_config.json, skipping") - return - - # Embed the chat_template - tokenizer_config["chat_template"] = chat_template_content - - # Save updated tokenizer_config.json - with open(output_tokenizer_config, "w") as f: - json.dump(tokenizer_config, f, indent=2, ensure_ascii=False) - - logging.info(f"✓ Embedded chat_template from NeMo checkpoint into tokenizer_config.json") - logging.debug(f" Template length: {len(chat_template_content)} characters") diff --git a/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py b/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py deleted file mode 100644 index 11a8798ba6..0000000000 --- a/modelopt/torch/puzzletron/export/MCore/puzzletron_hf_config_utils.py +++ /dev/null @@ -1,142 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import asdict - -import torch -from megatron.core.transformer.spec_utils import ModuleSpec -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import ( - AutoTokenizer as NemoAutoTokenizer, -) -from nemo.collections.llm.gpt.model.base import GPTModel -from nemo.collections.llm.gpt.model.llama_nemotron import HFLlamaNemotronImporter -from nemo.lightning import io, teardown -from nemo.lightning.io.state import TransformFns -from nemo.utils.import_utils import safe_import -from transformers import AutoModelForCausalLM, AutoTokenizer - -from modelopt.torch.puzzletron.export.MCore.puzzletron_layer_specs import ( - PuzzletronAttentionConfig, - PuzzletronHeterogeneousTransformerConfig, - PuzzletronMLPConfig, - get_gpt_heterogeneous_layer_spec_puzzletron, -) - - -def convert_attention_config_from_cfg_object(attention_config, num_attention_heads, head_dim): - for unsupported_key in [ - "llama4", - "num_sink_tokens", - "sparsify", - "unshifted_sink", - "use_prefill_window_in_sink_attention", - ]: - if hasattr(attention_config, unsupported_key) and getattr( - attention_config, unsupported_key - ) not in [ - None, - False, - ]: - # - # if attention_config.get(unsupported_key, None) not in [None, False]: - raise NotImplementedError(f"{unsupported_key} is not supported") - window_size = attention_config.window_size if hasattr(attention_config, "window_size") else None - if window_size is not None: - window_size = (window_size, 0) - is_mamba = attention_config.mamba if hasattr(attention_config, "mamba") else False - n_heads_in_group = ( - attention_config.n_heads_in_group if hasattr(attention_config, "n_heads_in_group") else 1 - ) - if n_heads_in_group is None: - n_heads_in_group = 1 - return asdict( - PuzzletronAttentionConfig( - no_op=attention_config.no_op if hasattr(attention_config, "no_op") else False, - replace_with_linear=( - attention_config.replace_with_linear - if hasattr(attention_config, "replace_with_linear") - else False - ), - num_attention_heads=num_attention_heads, - num_query_groups=num_attention_heads // n_heads_in_group, - kv_channels=head_dim, - window_size=window_size, - multi_latent_attention=False, - is_mamba=is_mamba, - mamba_state_dim=( - attention_config.mamba.state_dim - if is_mamba and hasattr(attention_config.mamba, "state_dim") - else 128 - ), - mamba_head_dim=( - attention_config.mamba.head_dim - if is_mamba and hasattr(attention_config.mamba, "head_dim") - else 64 - ), - mamba_num_groups=( - attention_config.mamba.num_groups - if is_mamba and hasattr(attention_config.mamba, "num_groups") - else 8 - ), - mamba_num_heads=( - attention_config.mamba.num_heads - if is_mamba and hasattr(attention_config.mamba, "num_heads") - else None - ), - ) - ) - - -def convert_mlp_config_from_cfg_object(mlp_config, parallel_blocks, default_hidden_act): - if parallel_blocks is not None: - raise NotImplementedError("parallel_blocks is not supported") - if not hasattr(mlp_config, "gated") or mlp_config.gated is False: - raise NotImplementedError("non-gated MLP is not supported") - if not hasattr(mlp_config, "hidden_act") or mlp_config.hidden_act not in [default_hidden_act]: - raise NotImplementedError(f"all mlps must have the same activation ({default_hidden_act})") - if hasattr(mlp_config, "sparsify") and mlp_config.sparsify is not None: - raise NotImplementedError("sparsify is not supported") - is_moe = hasattr(mlp_config, "moe") and mlp_config.moe is not None - return asdict( - PuzzletronMLPConfig( - no_op=mlp_config.no_op if hasattr(mlp_config, "no_op") else False, - replace_with_linear=mlp_config.replace_with_linear - if hasattr(mlp_config, "replace_with_linear") - else False, - ffn_hidden_size=mlp_config.intermediate_size - if hasattr(mlp_config, "intermediate_size") - else None, - num_moe_experts=( - mlp_config.moe.num_local_experts - if is_moe and hasattr(mlp_config.moe, "num_local_experts") - else None - ), - moe_shared_expert_intermediate_size=( - mlp_config.moe.shared_expert_intermediate_dim - if is_moe and hasattr(mlp_config.moe, "shared_expert_intermediate_dim") - else None - ), - moe_ffn_hidden_size=( - mlp_config.moe.expert_intermediate_dim - if is_moe and hasattr(mlp_config.moe, "expert_intermediate_dim") - else None - ), - moe_router_topk=( - mlp_config.moe.num_experts_per_tok - if is_moe and hasattr(mlp_config.moe, "num_experts_per_tok") - else 2 - ), - ) - ) diff --git a/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py b/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py deleted file mode 100644 index ec011ff288..0000000000 --- a/modelopt/torch/puzzletron/export/MCore/puzzletron_layer_specs.py +++ /dev/null @@ -1,928 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from dataclasses import asdict, dataclass, field, fields -from pathlib import Path -from typing import Any, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.gpt.gpt_layer_specs import ( - LayerType, - LNImpl, - TransformerBlockSubmodules, - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, - get_num_layers_to_build, - get_transformer_layer_offset, -) -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.post_training.modelopt.layers import Linear -from megatron.core.process_groups_config import ModelCommProcessGroups -from megatron.core.quantization.utils import ( - kitchen_quantization_recipe_config, - load_quantization_recipe, -) -from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules -from megatron.core.tensor_parallel.layers import ( - ColumnParallelLinear, - RowParallelLinear, - _initialize_affine_weight_cpu, -) -from megatron.core.tensor_parallel.random import get_cuda_rng_tracker -from megatron.core.transformer import MLATransformerConfig, TransformerConfig -from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version - -# from megatron.core.activations import squared_relu #for megatron 0.14 version in future NeMo containers -from megatron.training.activations import squared_relu -from nemo.collections.llm.gpt.model.llama import Llama31Config70B -from packaging.version import Version as PkgVersion -from torch import Tensor -from torch.nn.parameter import Parameter - -try: - import transformer_engine as te # pylint: disable=unused-import - from megatron.core.extensions.transformer_engine import ( - TELayerNormColumnParallelLinear, - TELinear, - TENorm, - TERowParallelLinear, - _get_extra_te_kwargs, - ) - - HAVE_TE = True -except ImportError: - HAVE_TE = False - -# TODO: check sharded_state_dict_keys_map => only if TE is disabled -# TODO: parallel Blocks -# TODO: multimodal -# https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/vlm/neva/model/base.py -# https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/vlm/qwen2vl/model/base.py - - -# NOTE based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py#L144 -# TODO: what is the difference between this and the referenced one? -def _get_sharded_state_dict_keys_map( - block_config: "PuzzletronTransformerBlockConfig", use_transformer_engine: bool -): - """Generate a mapping of sharded state dictionary keys for Puzzletron transformer blocks. - - This function is a specialized version of the original Megatron-LM - `_get_sharded_state_dict_keys_map` function, adapted for Puzzletron's - heterogeneous transformer architecture with Mamba support. - - Key differences from the original: - - **Mamba Support**: Adds mapping for Mamba layers (`mixer.norm_`) - - **Enhanced Logic**: Uses if-elif-else structure instead of multiple if statements - - **No-op Handling**: Explicit handling of no-op attention and MLP cases - - **Simplified**: Removes `num_query_groups` check (handled in main logic) - - **Config Type**: Uses `PuzzletronTransformerBlockConfig` instead of `TransformerBlockConfig` - - Args: - block_config: Puzzletron transformer block configuration - use_transformer_engine: Whether to use Transformer Engine optimizations - - Returns: - dict: A dictionary mapping sharded state dictionary keys - """ - mapping = {} - if not use_transformer_engine: - if block_config.attention.replace_with_linear: - mapping.update({"input_layernorm.": "self_attention.layer_norm_"}) - elif block_config.attention.is_mamba: # Mamba, not sure about this - mapping.update({"input_layernorm.": "mixer.norm_"}) - elif not block_config.attention.no_op: # MHA and MLA - mapping.update({"input_layernorm.": "self_attention.linear_qkv.layer_norm_"}) - else: # No-op - pass - - if block_config.mlp.ffn_hidden_size is not None: # FFN - mapping.update({"pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_"}) - elif block_config.mlp.replace_with_linear: # Linear - mapping.update({"pre_mlp_layernorm.": "mlp.layer_norm_"}) - else: # No-op, MoE - pass - return mapping - - -# NOTE: new class -@dataclass -class PuzzletronSubblockConfig: - """Base configuration class for Puzzletron transformer subblocks. - - This is the base class for attention and MLP configurations in Puzzletron's - heterogeneous transformer architecture. It provides common functionality - for subblock configurations including no-op and linear replacement options. - - Key differences from the original Megatron-LM subblock configs: - - **Enhanced Building**: Uses `build_config_from_dict()` with main config fallback - - **Validation**: Includes `__post_init__()` validation for mutual exclusivity - - **Flexibility**: Supports both no-op and linear replacement modes - - Attributes: - no_op: Whether this subblock should be a no-op operation - replace_with_linear: Whether to replace the subblock with a single linear layer - """ - - no_op: bool = False - replace_with_linear: bool = False - - @classmethod - def build_config_from_dict( - cls, - subblock_config_dict: dict[str, Any], - main_config: "PuzzletronHeterogeneousTransformerConfig", - ): - field_names = {f.name for f in fields(cls)} - subblock_config_dict = {k: v for k, v in subblock_config_dict.items() if k in field_names} - # getting default values from the main config (if not overridden in the subblock config) - for field_name in field_names: - # note that MLA fields are also not in the main_config - if field_name not in subblock_config_dict and hasattr(main_config, field_name): - subblock_config_dict[field_name] = getattr(main_config, field_name) - return cls(**subblock_config_dict) - - def __post_init__(self) -> None: - assert not (self.no_op and self.replace_with_linear), ( - "at most one of no_op, replace_with_linear can be True" - ) - - -@dataclass -class PuzzletronAttentionConfig(PuzzletronSubblockConfig): - """Configuration parameters for the self-attention part of a Puzzletron transformer block. - - This class extends the original Megatron-LM AttentionConfig with support for - Mamba layers and enhanced Multi-Latent Attention (MLA) configurations. - - Key differences from the original AttentionConfig: - - **Mamba Support**: Adds `is_mamba` flag and Mamba-specific parameters - - **Enhanced MLA**: Extended MLA parameters with LoRA ranks and head dimensions - - **Context Parallelism**: Adds `cp_comm_type` for attention context parallelism - - **Validation**: Enhanced `__post_init__()` with Mamba-MLA mutual exclusivity check - - **Flexibility**: Supports MHA, MLA, and Mamba attention types in one config - - Attributes: - # MHA (Multi-Head Attention) parameters - num_attention_heads: Number of attention heads - num_query_groups: Number of query groups for grouped query attention - kv_channels: Key-value projection dimension - window_size: Sliding window size for local attention - - # MLA (Multi-Latent Attention) parameters - multi_latent_attention: Whether to use MLA instead of MHA - q_lora_rank: LoRA rank for query projections - kv_lora_rank: LoRA rank for key-value projections - qk_head_dim: Query-key head dimension - qk_pos_emb_head_dim: Query-key positional embedding head dimension - v_head_dim: Value head dimension - - # Context parallelism - cp_comm_type: Communication type for context parallelism - - # Mamba parameters - is_mamba: Whether to use Mamba instead of attention - mamba_state_dim: Mamba state dimension - mamba_head_dim: Mamba head dimension - mamba_num_groups: Number of groups in Mamba - mamba_num_heads: Number of heads in Mamba (auto-calculated if None) - """ - - # all attributes, except for is_mamba are part of TransformerConfig/MLATransformerConfig - # MHA - num_attention_heads: Optional[int] = None - num_query_groups: Optional[int] = None - kv_channels: Optional[int] = None - window_size: Optional[Tuple[int, int]] = None - # MLA (Note that for MLA we have to instantiate a MLATransformerConfig, since there is a isinstance in attention.py) - multi_latent_attention: bool = False - q_lora_rank: int = 512 - kv_lora_rank: int = 512 - qk_head_dim: int = 128 - qk_pos_emb_head_dim: int = 64 - v_head_dim: int = 128 - # for attention context parallelism (ignored for mamba) - cp_comm_type: str = "p2p" - # Mamba - is_mamba: bool = False # new - mamba_state_dim: int = 128 - mamba_head_dim: int = 64 - mamba_num_groups: int = 8 - mamba_num_heads: Optional[int] = None - - def __post_init__(self) -> None: - super().__post_init__() - if self.no_op or self.replace_with_linear: - self.is_mamba = False - self.num_attention_heads = 8 - self.multi_latent_attention = False - if self.is_mamba: - if self.num_attention_heads is None or self.num_attention_heads == 0: - self.num_attention_heads = 8 # to avoid division by zero - assert not (self.is_mamba and self.multi_latent_attention), ( - "Mamba and MLA cannot be used together" - ) - - -@dataclass -class PuzzletronMLPConfig(PuzzletronSubblockConfig): - """Configuration parameters for the MLP part of a Puzzletron transformer block. - - This class extends the original Megatron-LM MLPConfig with enhanced - Mixture of Experts (MoE) support and improved configuration building. - - Key differences from the original MLPConfig: - - **Enhanced MoE**: Extended MoE parameters with shared expert support - - **Validation**: Includes `__post_init__()` validation for no-op/linear modes - - **Building**: Uses `build_config_from_dict()` with main config fallback - - **Flexibility**: Supports standard MLP, MoE, no-op, and linear replacement modes - - Attributes: - # Standard MLP parameters - ffn_hidden_size: MLP intermediate size (hidden dimension) - - # MoE (Mixture of Experts) parameters - num_moe_experts: Number of expert networks in MoE - moe_shared_expert_intermediate_size: Size of shared expert intermediate layer - moe_ffn_hidden_size: Hidden size for MoE expert networks - moe_router_topk: Number of top-k experts to route tokens to - """ - - # all attributes are part of TransformerConfig - ffn_hidden_size: Optional[int] = None - # MoE - num_moe_experts: Optional[int] = None - moe_shared_expert_intermediate_size: Optional[int] = None - moe_ffn_hidden_size: Optional[int] = None - moe_router_topk: int = 2 - - def __post_init__(self) -> None: - super().__post_init__() - if self.no_op or self.replace_with_linear: - self.ffn_hidden_size = None - self.num_moe_experts = None - self.moe_ffn_hidden_size = None - - -# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/transformer/heterogeneous/heterogeneous_config.py#L134 -@dataclass -class PuzzletronTransformerBlockConfig: - """Configuration for a single Puzzletron transformer block in a heterogeneous model. - - This class represents the configuration for one transformer block, containing - both attention and MLP subblock configurations. It's based on the original - Megatron-LM TransformerBlockConfig but uses Puzzletron-specific subblock configs. - - Key differences from the original TransformerBlockConfig: - - **Puzzletron Subblocks**: Uses `PuzzletronAttentionConfig` and `PuzzletronMLPConfig` - - **Enhanced Building**: Uses `build_from_dict()` with main config fallback - - **Mamba Support**: Supports Mamba layers through attention config - - **MoE Support**: Enhanced MoE support through MLP config - - **Flexibility**: Supports all Puzzletron attention and MLP variants - - Attributes: - attention: Configuration for the attention subblock (MHA, MLA, or Mamba) - mlp: Configuration for the MLP subblock (standard MLP or MoE) - """ - - attention: PuzzletronAttentionConfig - mlp: PuzzletronMLPConfig - - @classmethod - def build_from_dict( - cls, block: dict[str, Any], main_config: "PuzzletronHeterogeneousTransformerConfig" - ): - if "mlp" in block: - mlp = block["mlp"] - elif "ffn" in block: - mlp = block["ffn"] - else: - raise ValueError(f"mlp/ffn not found in block: {block}") - - return cls( - attention=PuzzletronAttentionConfig.build_config_from_dict( - subblock_config_dict=block["attention"], main_config=main_config - ), - mlp=PuzzletronMLPConfig.build_config_from_dict( - subblock_config_dict=mlp, main_config=main_config - ), - ) - - -@dataclass -class PuzzletronMambaTransformerConfig(TransformerConfig): - """Configuration for Puzzletron Mamba-only transformer models. - - This class extends the base TransformerConfig for models that use - Mamba layers exclusively instead of attention mechanisms. It inherits - all standard transformer configuration parameters from TransformerConfig. - - Key differences from standard TransformerConfig: - - **Mamba Focus**: Designed specifically for Mamba-based architectures - - **Inheritance**: Inherits all standard transformer parameters - - **Simplicity**: Currently a pass-through class for future Mamba-specific extensions - - Note: This class is currently minimal and inherits all functionality - from the base TransformerConfig. Future versions may add Mamba-specific - configuration parameters as needed. - """ - - -# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/transformer/heterogeneous/heterogeneous_config.py#L147 -@dataclass -class PuzzletronHeterogeneousTransformerConfig(TransformerConfig): - """Configuration object for Puzzletron heterogeneous transformers. - - This class extends the original Megatron-LM HeterogeneousTransformerConfig with - enhanced support for Mamba layers and improved configuration management. - - Key differences from the original HeterogeneousTransformerConfig: - - **Mamba Support**: Adds Mamba-specific parameters for state-space models - - **Enhanced Block Configs**: Uses `PuzzletronTransformerBlockConfig` with Mamba support - - **Improved Building**: Enhanced `__post_init__()` with better config validation - - **Flexibility**: Supports all Puzzletron attention and MLP variants - - Heterogeneous models refer to transformer architectures where individual layers can differ - in configuration. Specifically: - - Attention layers can be MHA, MLA, Mamba, Linear, or No-op (all with their own config) - - MLP layers can be MoE, MLP, Linear, or No-op (all with their own config) - - Layers can have parallel blocks that run simultaneously and sum their outputs - - Mamba Parameters (shared across all Mamba layers): - d_conv: Convolution dimension for Mamba - expand: Expansion factor for Mamba hidden dimension - D_has_hdim: Whether D matrix has hidden dimension - rmsnorm: Whether to use RMS normalization - norm_before_gate: Whether to normalize before gating - dt_min/max/scale: Delta time parameters for Mamba - bias/conv_bias: Bias settings for Mamba layers - chunk_size: Chunk size for Mamba processing - """ - - heterogeneous_layers_config_path: str = "" - """Path to the json file containing the heterogeneous block specs.""" - - heterogeneous_layers_config_encoded_json: str = "" - """The contents of the json file containing the heterogeneous block specs. It will be read from - heterogeneous_layers_config_path at first, then saved forever inside the model checkpoint.""" - - per_block_parameters: list[PuzzletronTransformerBlockConfig] = field(init=False) - """Configuration parameters for each of the transformer blocks in a - heterogeneous transformer.""" - - # all of these can be used to instantiate a MambaMixer, they are shared for all Mamba layers - d_conv: int = 4 - expand: int = 2 - D_has_hdim: bool = False - rmsnorm: bool = True - norm_before_gate: bool = False - dt_min: float = 0.001 - dt_max: float = 0.1 - dt_scale: float = 1.0 - bias: bool = False - conv_bias: bool = True - chunk_size: int = 128 - - def __post_init__(self) -> None: - if self.kv_channels is None and self.num_attention_heads == 0: - self.num_attention_heads = 8 # to avoid division by zero - # Type assertion to help mypy understand the type after the check - assert isinstance(self.num_attention_heads, int), "num_attention_heads must be an integer" - if self.heterogeneous_layers_config_encoded_json in ("", None): - assert self.heterogeneous_layers_config_path not in ( - None, - "", - ), ( - "heterogeneous_layers_config_path is required, if heterogeneous_layers_config_encoded_json is not provided" - ) - self.heterogeneous_layers_config_encoded_json = Path( - self.heterogeneous_layers_config_path - ).read_text() - hf_config_dict: dict[str, Any] = json.loads(self.heterogeneous_layers_config_encoded_json) - block_list = hf_config_dict["block_configs"] - # TODO: should we change the definition of num_layers? it can be sum(mlp/attention) rather than uneven blocks - if self.num_layers is None or self.num_layers == 0: - self.num_layers = len(block_list) - # Type assertion to help mypy understand the type after the check - assert isinstance(self.num_layers, int), "num_layers must be an integer" - assert self.num_layers == len(block_list), ( - "num_layers must match the number of blocks in the json file" - ) - super().__post_init__() - self.heterogeneous_block_specs = True - self.heterogeneous_dist_checkpoint = True # TODO: check if this is correct/needed - self.per_block_parameters = [ - PuzzletronTransformerBlockConfig.build_from_dict(block=block, main_config=self) - for block in block_list - ] - - # TODO add parallel blocks support - def get_config_for_layer( - self, layer_number: int - ) -> TransformerConfig | MLATransformerConfig | PuzzletronMambaTransformerConfig: - """ - Get the config for the given layer number. - Based on the layer number, the corresponding block config is returned, - overriding the main config's value. - - Returns: - TransformerConfig: For standard transformer layers - MLATransformerConfig: For MLA layers - PuzzletronMambaTransformerConfig: For Mamba layers - """ - layer_idx = layer_number - 1 # layer number starts from 1 - if layer_idx < 0 or layer_idx >= len(self.per_block_parameters): - raise ValueError( - f"Invalid layer number: {layer_number}. Should be in " - f"range [1, {len(self.per_block_parameters)}]." - ) - block_config = self.per_block_parameters[layer_idx] - - # Determine which config class to use based on the block configuration - if block_config.attention.is_mamba: - config_class = PuzzletronMambaTransformerConfig - elif block_config.attention.multi_latent_attention: - config_class = MLATransformerConfig - else: - config_class = TransformerConfig - - # Get all available fields from the attention and MLP configs - attention_fields = {f.name for f in fields(block_config.attention)} - mlp_fields = {f.name for f in fields(block_config.mlp)} - - # Get all available fields from the target config class - target_config_fields = {f.name for f in fields(config_class)} - - # Start with the base config - transformer_config_dict = asdict(self) - - # Remove keys that are not in the target config class - transformer_config_dict = { - k: v for k, v in transformer_config_dict.items() if k in target_config_fields - } - - # Update with all available attention config values (if they exist in target config) - for field_name in attention_fields: - if field_name in target_config_fields: - transformer_config_dict[field_name] = getattr(block_config.attention, field_name) - - # Update with all available MLP config values (if they exist in target config) - for field_name in mlp_fields: - if field_name in target_config_fields: - transformer_config_dict[field_name] = getattr(block_config.mlp, field_name) - - if transformer_config_dict["num_moe_experts"] is None: - # to pass __post_init__ of config_class - transformer_config_dict["expert_model_parallel_size"] = 1 - config = config_class(**transformer_config_dict) - - return config - - -# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/ba97a7e282a8478a02d012bc9b9e45f3a6be216e/megatron/core/extensions/transformer_engine.py#L449 -class WrappedTENormLinear(TELayerNormColumnParallelLinear): - """A wrapper around TELayerNormColumnParallelLinear with simplified interface and forced configurations. - - This wrapper simplifies the interface of TELayerNormColumnParallelLinear by: - 1. Taking only a config object instead of individual parameters - 2. Forcing specific configurations (tp_group=None, tp_size=1, etc.) for compatibility - 3. Adding version checks for Transformer Engine features - 4. Providing a cleaner interface for heterogeneous transformer models - - Key differences from TELayerNormColumnParallelLinear: - - Simplified constructor: only requires config and optional unused parameters - - Forces tensor parallel settings: tp_group=None, tp_size=1, tp_rank=0 - - Automatically sets input_size=output_size=config.hidden_size - - Adds version checks for TE features (delay_wgrad_compute, normalization, symmetric_ar_type) - - Forces bias=False, skip_bias_add=False for consistency - - Disables gather_output (raises error if True) - - Uses simplified init_method=lambda w: None - - This wrapper is designed for use in heterogeneous transformer architectures where - individual layers may have different configurations but need a consistent interface. - """ - - def __init__( - self, - config, - layer_number=None, # unused - model_comm_pgs=None, # unused - cp_comm_type=None, # unused - tp_group=None, # unused - tp_comm_buffer_name=None, - gather_output=False, # unused - ): - # unfortunately, TELayerNormColumnParallelLinear sets tp_group and forcing it to be None requires to copy/paste __init__ - if not HAVE_TE: - raise ImportError( - "Transformer Engine is not installed. " - "Please install it with `pip install transformer-engine`." - ) - - self.config = config - - if gather_output: - raise ValueError("Transformer Engine linear layers do not support gather_output = True") - - skip_bias_add = False - bias = False - - # TE returns a zero length Tensor when bias=False and - # return_bias=True, but we prefer None. So in that case we - # tell TE to not return the bias, and return None - # ourselves. This way our forward always returns two values - # and we don't have to deal with the zero length Tensor. - self.te_return_bias = skip_bias_add and bias - self.is_first_microbatch = True - self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache - extra_kwargs = _get_extra_te_kwargs(config) - self.tp_size = 1 - self.tp_rank = 0 - - if self.config.delay_wgrad_compute: - if is_te_min_version("2.3.0"): - extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute - else: - raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") - - # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` - if is_te_min_version("0.11.0"): - extra_kwargs["normalization"] = self.config.normalization - elif self.config.normalization != "LayerNorm": - te_version = get_te_version() - raise ValueError( - f"Transformer Engine v{te_version} does not support {self.config.normalization}." - ) - - if self.config.symmetric_ar_type is not None: - assert is_torch_min_version("2.7.0a0"), "Must have at least torch version 2.7 or higher" - assert is_te_min_version("2.3.0") or get_te_version() == PkgVersion( - "2.3.0.dev0+39c0e70" - ), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce" - extra_kwargs["symmetric_ar_type"] = self.config.symmetric_ar_type - - output_size = config.hidden_size - input_size = config.hidden_size - # calling te.pytorch.LayerNormLinear's __init__ - super(TELayerNormColumnParallelLinear, self).__init__( - in_features=input_size, - out_features=output_size, - eps=self.config.layernorm_epsilon, - sequence_parallel=self.config.sequence_parallel, - fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, - tp_group=None, - tp_size=1, - get_rng_state_tracker=( - get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None - ), - init_method=lambda w: None, - bias=bias, - return_bias=self.te_return_bias, - parallel_mode=None, - return_layernorm_output=False, - zero_centered_gamma=self.config.layernorm_zero_centered_gamma, - **extra_kwargs, - ) - - if config.use_cpu_initialization: - output_size_per_partition = output_size - _ = _initialize_affine_weight_cpu( - self.weight, - output_size, - input_size, - output_size_per_partition, - 0, - init_method=lambda w: None, - stride=1, - return_master_weight=False, - rank=self.tp_rank, - world_size=self.tp_size, - skip_set_tensor_parallel_attributes=True, - ) - if bias: - self.bias = Parameter( - torch.empty(output_size_per_partition, dtype=config.params_dtype) - ) - with torch.no_grad(): - self.bias.zero_() - - def forward(self, x, *args, **kwargs): - return super().forward(x) - - -class WrappedLinear(Linear): - def __init__( - self, - config, - layer_number=None, - model_comm_pgs=None, - cp_comm_type=None, - tp_group=None, - tp_comm_buffer_name=None, - gather_output=False, - ): - super().__init__( - input_size=config.hidden_size, - output_size=config.hidden_size, - config=config, - init_method=config.init_method, - bias=False, - gather_output=gather_output, - skip_bias_add=False, - tp_comm_buffer_name=tp_comm_buffer_name, - tp_group=tp_group, - ) - - def forward(self, x, *args, **kwargs): - return super().forward(x) - - -class WrappedTELinear(TELinear): - # TODO: docstring - def __init__( - self, - config, - layer_number=None, # unused - model_comm_pgs=None, # unused - cp_comm_type=None, # unused - tp_group=None, # unused - tp_comm_buffer_name=None, - gather_output=False, # unused - ): - super().__init__( - input_size=config.hidden_size, - output_size=config.hidden_size, - parallel_mode="duplicated", - # parallel_mode=None, - config=config, - init_method=config.init_method, - bias=False, - skip_bias_add=False, - skip_weight_param_allocation=False, - tp_comm_buffer_name=tp_comm_buffer_name, - is_expert=False, - ) - - def forward(self, x, *args, **kwargs): - return super().forward(x) - - -class WrappedMambaMixer(MambaMixer): - def __init__(self, *args, cp_comm_type: Optional[str] = None, **kwargs): - # ignoring cp_comm_type - super().__init__(*args, **kwargs) - - def forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - key_value_states: Optional[Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, - rotary_pos_cos: Optional[Tensor] = None, - rotary_pos_sin: Optional[Tensor] = None, - attention_bias: Optional[Tensor] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[int] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - ) -> Tuple[Tensor, Tensor]: - result = super().forward(hidden_states, inference_context=inference_context) - # Ensure we return a tuple of two tensors - assert isinstance(result, tuple) and len(result) == 2 - return result - - -# NOTE: new method -def get_layer_spec_for_layer( - block_params: PuzzletronTransformerBlockConfig, - config: PuzzletronHeterogeneousTransformerConfig, - use_transformer_engine: bool, - normalization: Optional[str] = None, - qk_l2_norm: Optional[bool] = False, -) -> ModuleSpec: - # this part is copied from megatron.core.models.gpt.gpt_layer_specs.get_gpt_decoder_block_spec() - if use_transformer_engine: - layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=block_params.mlp.num_moe_experts, - moe_grouped_gemm=False, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=block_params.attention.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - qk_l2_norm=qk_l2_norm, - use_kitchen=config.use_kitchen, - # use_te_activation_func=config.use_te_activation_func, #TODO: part of megatron 0.14 version. check if this is needed now. - ) - else: - layer_spec = get_gpt_layer_local_spec( - num_experts=block_params.mlp.num_moe_experts, - moe_grouped_gemm=False, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=block_params.attention.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - normalization=normalization, - qk_l2_norm=qk_l2_norm, - use_kitchen=config.use_kitchen, - ) - if block_params.attention.no_op: - layer_spec.submodules.input_layernorm = IdentityOp - layer_spec.submodules.self_attn_bda = IdentityFuncOp - layer_spec.submodules.self_attention = ModuleSpec(module=IdentityOp) - elif block_params.attention.replace_with_linear: - layer_spec.submodules.self_attention = ModuleSpec( - module=WrappedTENormLinear if use_transformer_engine else WrappedLinear, - params={"tp_comm_buffer_name": "linear_attn"}, - ) - elif block_params.attention.is_mamba: - mamba_mixer_params = dict( - d_model=config.hidden_size, - d_conv=config.d_conv, - expand=config.expand, - D_has_hdim=config.D_has_hdim, - rmsnorm=config.rmsnorm, - norm_before_gate=config.norm_before_gate, - dt_min=config.dt_min, - dt_max=config.dt_max, - dt_scale=config.dt_scale, - bias=config.bias, - conv_bias=config.conv_bias, - chunk_size=config.chunk_size, - ) - layer_spec.submodules.self_attention = ModuleSpec( - module=WrappedMambaMixer, - params=mamba_mixer_params, - submodules=MambaMixerSubmodules( - in_proj=( - TELayerNormColumnParallelLinear - if use_transformer_engine - else ColumnParallelLinear - ), - out_proj=TERowParallelLinear if use_transformer_engine else RowParallelLinear, - ), - ) - - if block_params.mlp.no_op: - layer_spec.submodules.pre_mlp_layernorm = IdentityOp - layer_spec.submodules.mlp_bda = IdentityFuncOp - layer_spec.submodules.mlp = ModuleSpec(module=IdentityOp) - elif block_params.mlp.replace_with_linear: - layer_spec.submodules.mlp = ModuleSpec( - module=WrappedTENormLinear if use_transformer_engine else WrappedLinear, - params={"tp_comm_buffer_name": "linear_mlp"}, - ) - - layer_spec.submodules.sharded_state_dict_keys_map = _get_sharded_state_dict_keys_map( - block_params, use_transformer_engine - ) - return layer_spec - - -# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py#L168 -def get_gpt_heterogeneous_layer_spec_puzzletron( - config: PuzzletronHeterogeneousTransformerConfig, - use_transformer_engine: bool, - normalization: Optional[str] = None, - qk_l2_norm: Optional[bool] = False, - vp_stage: Optional[int] = None, -) -> TransformerBlockSubmodules: - """Generate heterogeneous layer specifications for Puzzletron transformer models. - - This function is a specialized version of the original Megatron Core - `get_gpt_heterogeneous_layer_spec` function, adapted for Puzzletron's specific - heterogeneous transformer architecture requirements. - - Key differences from the original: - - **Signature**: Adds `normalization` and `qk_l2_norm` parameters, removes `pp_rank` - - **Architecture**: Uses `get_layer_spec_for_layer()` helper for modular layer creation - - **Pipeline Parallel**: Enhanced with `pipeline_model_parallel_layout` support - - **Configuration**: Uses `PuzzletronHeterogeneousTransformerConfig` with Mamba parameters - - **Layer Norm**: Simplified to `TENorm` vs `LNImpl` (removes `WrappedTorchNorm` complexity) - - **Features**: Supports Mamba layers, custom attention types, and advanced parallelization - - Args: - config: Puzzletron heterogeneous transformer configuration - use_transformer_engine: Whether to use Transformer Engine optimizations - normalization: Optional normalization type override - qk_l2_norm: Whether to apply L2 normalization to QK matrices - vp_stage: Virtual pipeline stage for advanced parallelization - - Returns: - TransformerBlockSubmodules: Complete layer specification for the heterogeneous model - """ - # Create the layer specs for the model. - layer_specs = [ - get_layer_spec_for_layer( - block_params, config, use_transformer_engine, normalization, qk_l2_norm - ) - for block_params in config.per_block_parameters - ] - - # Slice the layer specs to only include the layers that are built in this pipeline stage. - # Note: MCore layer_number starts at 1 - num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) - - if config.pipeline_model_parallel_layout is not None: - local_layer_specs = [ - layer_specs[layer_id] - for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list( - layer_type=LayerType.decoder, vp_stage=vp_stage - ) - ] - else: - offset = get_transformer_layer_offset(config, vp_stage=vp_stage) - local_layer_specs = layer_specs[offset : offset + num_layers_to_build] - - if use_transformer_engine: - layer_norm_impl = TENorm - else: - layer_norm_impl = LNImpl - - # Block spec. - block_spec = TransformerBlockSubmodules( - layer_specs=local_layer_specs, layer_norm=layer_norm_impl - ) - - return block_spec - - -# NOTE: based on https://github.com/NVIDIA/Megatron-LM/blob/aacc3b8aa5f0d3071431a94503d6233802fbaedd/gpt_builders.py#L23 -def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): - """Build a GPT model with Puzzletron's heterogeneous transformer architecture. - - This function is a specialized version of the original Megatron-LM `gpt_builder` function, - adapted for Puzzletron's heterogeneous transformer architecture requirements. - - Key differences from the original: - - **Simplified**: Focuses exclusively on heterogeneous models (rejects legacy, spec-based, MoE, MTP) - - **Configuration**: Only supports args-based config (removes YAML complexity) - - **Layer Spec**: Uses single `get_gpt_heterogeneous_layer_spec_puzzletron` function - - **Error Handling**: Explicit error messages for unsupported features - - **Logging**: Removes debug logging for cleaner implementation - - Args: - args: Command-line arguments namespace containing model configuration parameters - pre_process: Whether to include pre-processing layers - post_process: Whether to include post-processing layers - vp_stage: Virtual pipeline stage for advanced parallelization - config: Optional pre-configured transformer config (if None, created from args) - - Returns: - GPTModel: Configured GPT model with heterogeneous transformer architecture - - Raises: - ValueError: If legacy models, spec-based models, or MTP are requested - """ - assert config is not None, "config is required" - if args.use_legacy_models: - raise ValueError("Legacy models are not supported") - if args.spec is not None: - raise ValueError("Spec is not supported") - use_te = args.transformer_impl == "transformer_engine" - transformer_layer_spec = get_gpt_heterogeneous_layer_spec_puzzletron( - config, - use_te, - normalization=args.normalization, - qk_l2_norm=args.qk_l2_norm, - vp_stage=vp_stage, - ) - mtp_block_spec = None - if args.mtp_num_layers is not None: - raise ValueError("MTP is not supported") - model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - rotary_base=args.rotary_base, - rope_scaling=args.use_rope_scaling, - mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, - ) - - return model diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index 8a7c2834fd..73661edba5 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -38,7 +38,6 @@ from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, - DeciLMForCausalLM, DeciLMMultiDecoderLayer, DeciLMRMSNorm, LMHead, @@ -59,7 +58,6 @@ ) from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( - create_dummy_model, is_in_safetensors_format, load_and_shard_model, load_sharded_state_dict, diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index ecfb8b857b..c4d2ea054e 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -39,11 +39,7 @@ update_model_config, ) from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict -from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( - _save_checkpoint, - copy_deci_lm_hf_code, - load_model_config, -) +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import _save_checkpoint, load_model_config from modelopt.torch.puzzletron.tools.logger import mprint from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config @@ -180,8 +176,6 @@ def init_child_from_parent( save_checkpoint_time = time.time() - start_time mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") - copy_deci_lm_hf_code(output_checkpoint_dir) - # Print profiling summary with actual worker counts used total_core_time = create_child_state_dict_time + save_checkpoint_time actual_layer_workers = max_layer_workers if max_layer_workers else "auto" diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 3647de5e25..020afdfadd 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -23,7 +23,6 @@ import dataclasses import fcntl import os -import shutil import time import warnings from collections import defaultdict @@ -37,10 +36,7 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from modelopt.torch.puzzletron.decilm import deci_lm_hf_code from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import maybe_cast_block_configs -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from modelopt.torch.puzzletron.tools.common import infer_weights_dtype from modelopt.torch.puzzletron.tools.logger import mprint from modelopt.torch.puzzletron.tools.post_init_sparse import SparsityMethod @@ -69,54 +65,6 @@ warnings.filterwarnings("ignore", "You are using `torch.load` with `weights_only=False`*.") -def load_checkpoint( - checkpoint_dir: Path | str, - model_config_overrides: dict | None = None, - ignore_unexpected_config_keys: bool = False, - trust_remote_code: bool = False, -) -> DeciLMForCausalLM: - """ - Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your - local repo code, not the code inside the checkpoint. - - Args: - checkpoint_dir: Path to checkpoint directory - model_config_overrides: Optional mapping of config overrides. - ignore_unexpected_config_keys: If True, ignore unexpected config keys. - trust_remote_code: If True, allows execution of custom code from the model repository. - This is a security risk if the model source is untrusted. Only set to True if you - trust the source of the model. Defaults to False for security. - """ - from modelopt.torch.puzzletron.tools.checkpoint_utils import ( - load_state_dict, # prevent circular import - ) - - if not isinstance(checkpoint_dir, Path): - checkpoint_dir = Path(checkpoint_dir) - - model_config = load_model_config( - checkpoint_dir, - model_config_overrides=model_config_overrides, - ignore_unexpected_config_keys=ignore_unexpected_config_keys, - trust_remote_code=trust_remote_code, - ) - - # Without sparsity we could have done: - # model = DeciLMForCausalLM.from_pretrained(pretrained_model_name_or_path=checkpoint_dir, config=model_config) - state_dict = load_state_dict(checkpoint_dir) - state_dict, sparsity_masks = SparsityMethod.fix_state_dict_inplace(state_dict, verbose=True) - dtype = infer_weights_dtype(state_dict) - model = DeciLMForCausalLM.from_pretrained( - pretrained_model_name_or_path=None, - config=model_config, - state_dict=state_dict, - torch_dtype=dtype, - ) - SparsityMethod().apply_masks(model, sparsity_masks) - - return model - - def force_cache_dynamic_modules( config: PretrainedConfig, checkpoint_dir: Path | str, trust_remote_code: bool = False ): @@ -233,33 +181,6 @@ def _save_checkpoint( ) -def split_checkpoint_to_subblocks( - checkpoint_dir: Path | str, trust_remote_code: bool = False -) -> None: - """Split a checkpoint into subblocks. - - Args: - checkpoint_dir: Path to checkpoint directory - trust_remote_code: If True, allows execution of custom code from the model repository. - This is a security risk if the model source is untrusted. Only set to True if you - trust the source of the model. Defaults to False for security. - """ - from modelopt.torch.puzzletron.tools.checkpoint_utils import ( - load_state_dict, # prevent circular import - ) - - if not isinstance(checkpoint_dir, Path): - checkpoint_dir = Path(checkpoint_dir) - - model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) - state_dict = load_state_dict(checkpoint_dir) - save_subblocks(state_dict, checkpoint_dir) - - if (index_path := checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists(): - index_path.rename(checkpoint_dir / f"before_splitting.{SAFE_WEIGHTS_INDEX_NAME}") - save_safetensors_index(model_config, checkpoint_dir) - - def save_subblocks( state_dict: dict[str, torch.Tensor], checkpoint_dir: Path | str, @@ -374,51 +295,6 @@ def optimized_safe_save(kwargs): mprint(f" Save operation was {save_time / subblocks_total_time * 100:.1f}% of total time") -def save_safetensors_index( - model_config: DeciLMConfig, - checkpoint_dir: Path | str, -) -> None: - """Save safetensors index for DeciLM models (legacy function).""" - mprint("=== Starting save_safetensors_index profiling ===") - index_start_time = time.time() - - if not isinstance(checkpoint_dir, Path): - checkpoint_dir = Path(checkpoint_dir) - - # Step 1: Create fake model on meta device - fake_model_start_time = time.time() - with torch.device("meta"): - fake_model = DeciLMForCausalLM(model_config) - fake_model_time = time.time() - fake_model_start_time - mprint(f" Step 1 - Create fake model: {fake_model_time:.2f}s") - - # Step 2: Build weight map - weight_map_start_time = time.time() - weight_map = _build_safetensors_weight_map( - state_dict=fake_model.state_dict(), - non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, - module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, - layers_module_name=LAYERS_MODULE_NAME, - ) - weight_map_time = time.time() - weight_map_start_time - mprint(f" Step 2 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") - - # Step 3: Create and write index - write_start_time = time.time() - index = {"metadata": {"format": "pt"}, "weight_map": weight_map} - index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME - index_json = json_dumps(index) - _write_file_process_safe(index_json, index_path) - write_time = time.time() - write_start_time - mprint(f" Step 3 - Write index file: {write_time:.2f}s ({len(index_json)} chars)") - - index_total_time = time.time() - index_start_time - mprint(f"=== save_safetensors_index completed in {index_total_time:.2f}s ===") - mprint( - f" Breakdown: FakeModel {fake_model_time:.1f}s + WeightMap {weight_map_time:.1f}s + Write {write_time:.1f}s" - ) - - def _write_text(content: str, f: BinaryIO) -> None: f.write(content.encode("utf-8")) @@ -499,14 +375,3 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str for conf in model_config.block_configs ] model_config.save_pretrained(checkpoint_dir) - - -def copy_deci_lm_hf_code(output_dir: Path | str) -> None: - """ - Copy the deci_lm_hf_code directory to the output directory. - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - code_dir = Path(deci_lm_hf_code.__file__).parent - for path in code_dir.glob("*.py"): - shutil.copy(path, output_dir / path.name) diff --git a/modelopt/torch/puzzletron/tools/post_init_sparse.py b/modelopt/torch/puzzletron/tools/post_init_sparse.py index e2c45c4030..eb20250e68 100644 --- a/modelopt/torch/puzzletron/tools/post_init_sparse.py +++ b/modelopt/torch/puzzletron/tools/post_init_sparse.py @@ -17,8 +17,6 @@ from torch import nn from torch.nn.utils.prune import custom_from_mask -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM - """ Converts a state dictionary from PyTorch's pruning format (with _orig and _mask suffixes) into a standard format with sparsified weights. @@ -61,7 +59,7 @@ def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> N print(name) print(torch.sum(mask_dict[name]) / mask_dict[name].numel()) - def do_sparsity(self, model: DeciLMForCausalLM, mask_dict=None): + def do_sparsity(self, model: nn.Module, mask_dict=None): full_name_layers = [] for block_idx, block_config in enumerate(model.config.block_configs): ffn_names = block_config.ffn.sparsify # layers_to_sparsify_pattern[block_idx] diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py index 55926eaaea..c18867a576 100644 --- a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -40,15 +40,8 @@ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files -from typing_extensions import override import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( - DeciLMDecoderLayer, - DeciLMForCausalLM, - rope_type_to_class, -) from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch.puzzletron.tools.logger import mprint from modelopt.torch.puzzletron.utils.dummy_modules import ( @@ -60,51 +53,6 @@ from modelopt.torch.puzzletron.utils.utils import EmptyInitOnDevice -class DeciLMDummyBlock(DummyModule): - """Dummy block for DeciLM models (used by replacement_library).""" - - def __init__(self, config: DeciLMConfig, block_index: int): - super().__init__() - self.config = config - self.block_index = block_index - - @override - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torch.Tensor, None]: - if self.config.block_return_only_hidden_states: - return x - else: - return x, None - - -class DeciLMDummyWTE(DummyModule): - """Dummy word token embedding for DeciLM models (used by replacement_library).""" - - def __init__(self, config: DeciLMConfig, dtype: torch.dtype | None = None): - super().__init__() - self.n_embd = config.get_hidden_size() - self.dtype = dtype - - @override - def forward(self, input_ids: torch.Tensor) -> torch.Tensor: - B, T = input_ids.shape # noqa: N806 - result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device) - return result - - -class DeciLMDummyLMHead(DummyModule): - """Dummy LM head for DeciLM models (used by replacement_library).""" - - def __init__(self, config: DeciLMConfig): - super().__init__() - self.vocab_size = config.vocab_size - - @override - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, T, C = x.shape # noqa: N806 - result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device) - return result - - def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) -> None: """Set a submodule on a model by dotted path.""" parts = module_name.split(".") @@ -149,26 +97,6 @@ def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtim return model -def create_dummy_model( - model_config: DeciLMConfig, - dtype: torch.dtype, -) -> DeciLMForCausalLM: - with torch.device("meta"): - model = DeciLMForCausalLM(model_config) - - rope_cls = rope_type_to_class[model_config.position_embedding_type] - model.model.rotary_emb = rope_cls(config=model.config) - - model.model.set_input_embeddings(DeciLMDummyWTE(model.config, dtype)) - model.model.set_final_layer_norm(nn.Identity()) - model.set_output_embeddings(DeciLMDummyLMHead(model.config)) - - for block_index in range(model_config.get_num_hidden_layers()): - model.model.layers[block_index] = DeciLMDummyBlock(model.config, block_index) - - return model - - def _get_model_class_from_config(config: PretrainedConfig): """ Get the model class from config.architectures field. diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index d253c94457..6bf966a2ae 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -35,7 +35,6 @@ import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.anymodel.converter import Converter from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.replacement_library.replacement_library import ReplacementLibrary from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement from modelopt.torch.puzzletron.tools import validate_model @@ -43,10 +42,7 @@ SAFETENSORS_SUBBLOCKS_DIR_NAME, copy_tokenizer, ) -from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( - save_checkpoint, - save_safetensors_index, -) +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch.puzzletron.tools.validation_utils import ( validate_model_and_extract_hidden_states, @@ -190,6 +186,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None: Converter.copy_checkpoint_files(args.teacher_dir, checkpoint_dir) if realizable_as_symlinks: if dist.is_master(): + # TODO: Loo into internal Puzzleron code to see how to save as symlinks # save_checkpoint_as_symlinks is currently not supported pass save_checkpoint(model, checkpoint_dir, descriptor) @@ -230,39 +227,6 @@ def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool: return True -def force_create_symlink(src: Path, dst: Path) -> None: - if dst.exists(): - dst.unlink() - dst.symlink_to(src) - - -def save_checkpoint_as_symlinks( - layer_replacements: list[dict], - model_config: DeciLMConfig, - checkpoint_dir: Path, - replace_library: ReplacementLibrary, -) -> None: - model_config.save_pretrained(checkpoint_dir) - (checkpoint_dir / "subblocks_safetensors").mkdir(parents=True, exist_ok=True) - save_safetensors_index(model_config, checkpoint_dir) - - for layer_replacement in layer_replacements: - for weight_path in layer_replacement["weight_paths"]: - force_create_symlink( - weight_path, checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME / weight_path.name - ) - - lm_head_path = replace_library.get_teacher_lm_head_path() - force_create_symlink( - lm_head_path, checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME / lm_head_path.name - ) - - embedding_path = replace_library.get_teacher_embedding_path() - force_create_symlink( - embedding_path, checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME / embedding_path.name - ) - - def _load_tokenizer(args: DictConfig) -> PreTrainedTokenizerBase: tokenizer = None if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: From 419027593a8583e3faabda41767350e2f6aebcb3 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Mar 2026 18:24:48 +0100 Subject: [PATCH 47/62] Dkorzekwa/decilm hf code cleanup 2 (#1073) ### What does this PR do? Delete not used decilm code ## Summary by CodeRabbit * **Refactor** * Removed DeciLM-specific components including decoder layers, attention implementations, and specialized cache utilities, streamlining the codebase * Updated replacement library to use generic model configurations instead of DeciLM-specific types, improving compatibility with diverse architectures * Cleaned up internal utilities for attention masking, flash attention compatibility, and rotary position embeddings --------- Signed-off-by: Daniel Korzekwa --- .../deci_lm_hf_code/configuration_decilm.py | 2 - .../megatron_lm__mamba_mixer.py | 527 ---- .../decilm/deci_lm_hf_code/modeling_decilm.py | 1553 +--------- ...ormers_4_44_2__modeling_attn_mask_utils.py | 498 ---- ...g_flash_attention_utils_backward_compat.py | 363 --- .../transformers_4_44_2__modeling_outputs.py | 1768 ------------ .../transformers_4_44_2__pytorch_utils.py | 32 - .../transformers_4_51_3__cache_utils.py | 2535 ----------------- ...rmers_4_51_3__modeling_llama4_attention.py | 289 -- .../decilm/deci_lm_hf_code/variable_cache.py | 213 -- .../decilm/deci_lm_hf_code/vllm_yarn_utils.py | 210 -- .../replacement_library.py | 258 +- .../replacement_library/replacement_utils.py | 9 +- 13 files changed, 18 insertions(+), 8239 deletions(-) delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py index 6ff0e26a4e..34a7e8cfcf 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py @@ -27,13 +27,11 @@ # fakes imports to make AutoConfig infer dependencies from .transformers_4_44_2__modeling_rope_utils import rope_config_validation -from .transformers_4_51_3__cache_utils import HybridChunkedCache from .transformers_4_51_3__configuration_llama4 import Llama4Config # make sure that auto-formatting doesn't remove the fake imports rope_config_validation Llama4Config -HybridChunkedCache class DeciLMConfig(LlamaConfig): diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py deleted file mode 100644 index 76dbb3473b..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py +++ /dev/null @@ -1,527 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (c) 2024, Tri Dao, Albert Gu. - -# Adapted from megatron.core.ssm.mamba_mixer.MambaMixer: -# https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/0b5140009fb9011eceaef6d36ea1181a8d176479/megatron/core/ssm/mamba_mixer.py - -# ruff: noqa: N803, N806 - -# Some of this code was adopted from https://github.com/state-spaces/mamba/ -# This source code is licensed under the Apache license found in the -# LICENSE file in the root directory of this source tree. - -import math -import warnings - -import torch -import torch.nn as nn -import torch.nn.functional as F - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - from einops import rearrange, repeat - from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import ( - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - ) - - class MambaMixerMegatron(nn.Module): - """ - Args: - d_model: The hidden size of the model. - d_state: The state size of the SSM. - d_conv: The number of channels in the causal convolution. - conv_init: The initialization range for the causal convolution weights. - nheads: The number of Mamba heads. Used to calculate the expansion factor for the SSM - instead of the deprecated arg "expand". - headdim: The hidden size of each attention head. - ngroups: The number of attention heads. - A_init_range: The initialization range for the attention weights. - D_has_hdim: Whether the D parameter has the same number of dimensions as the hidden - state. - rmsnorm: Whether to use root mean square normalization. - norm_before_gate: Whether to apply normalization before the gating mechanism. - dt_min: The minimum value of the dt parameter. - dt_max: The maximum value of the dt parameter. - dt_init: The initialization value of the dt parameter. - dt_scale: The scaling factor for the dt parameter. - dt_init_floor: The minimum value of the dt parameter after initialization. - bias: Whether to use bias in the linear layers. - conv_bias: Whether to use bias in the causal convolution. - chunk_size: The chunk size for the fused kernel. - use_mem_eff_path: Whether to use the memory-efficient path for the Mamba model. - layer_number: The layer number of this Mamba layer. - """ - - def __init__( - self, - d_model, - d_state=256, - d_conv=4, - conv_init=None, - nheads=256, - headdim=64, - ngroups=8, - A_init_range=(1, 16), - D_has_hdim=False, - rmsnorm=True, - norm_before_gate=False, - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - bias=False, - conv_bias=True, - # Fused kernel and sharding options - chunk_size=128, - use_mem_eff_path=True, - layer_number=None, - ): - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.conv_init = conv_init - self.nheads = nheads - self.headdim = headdim - self.ngroups = ngroups - self.D_has_hdim = D_has_hdim - self.rmsnorm = rmsnorm - self.norm_before_gate = norm_before_gate - self.chunk_size = chunk_size - self.use_mem_eff_path = use_mem_eff_path - self.layer_number = layer_number - - self.d_inner = self.nheads * self.headdim - - self.tensor_model_parallel_size = 1 - assert self.d_inner % self.tensor_model_parallel_size == 0 - assert self.ngroups % self.tensor_model_parallel_size == 0 - assert self.nheads % self.tensor_model_parallel_size == 0 - assert not bias - assert not self.norm_before_gate - - self.d_inner_local = self.d_inner // self.tensor_model_parallel_size - self.ngroups_local = self.ngroups // self.tensor_model_parallel_size - self.nheads_local = self.nheads // self.tensor_model_parallel_size - - assert self.d_inner_local % self.ngroups_local == 0 - - # Assume sequence parallelism: input is already partitioned along the - # sequence dimension - self.in_proj = nn.Linear( - self.d_model, - self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads, # AB CD E - bias=False, - ) - - conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state # A CD - - # weight dim: [conv_dim, conv_dim, d_conv] - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - ) - - if self.conv_init is not None: - nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) - - self.activation = "silu" - self.act = nn.SiLU() - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.nheads_local) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - self.dt_bias = nn.Parameter(inv_dt) - # Our initialization would set all Linear.bias to zero, - # need to mark this one as _no_reinit - self.dt_bias._no_reinit = True - # Just to be explicit. Without this we already don't - # put wd on dt_bias because of the check - - # name.endswith("bias") in param_grouping.py - self.dt_bias._no_weight_decay = True - - assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] - A = torch.empty(self.nheads_local, dtype=torch.float32).uniform_(*A_init_range) - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter( - torch.ones( - self.d_inner_local if self.D_has_hdim else self.nheads_local, - ) - ) # Keep in fp32 - self.D._no_weight_decay = True - - if self.rmsnorm: - self.norm = RMSNormGated( - self.d_inner_local, - eps=1e-5, - group_size=self.d_inner_local // self.ngroups_local, - norm_before_gate=self.norm_before_gate, - ) - - # Assume sequence parallelism: input is partitioned along d_inner and - # output is partitioned along the sequence dimension - self.out_proj = nn.Linear( - self.d_inner, - self.d_model, - bias=False, - ) - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (nL, B, D) / (L B D) - Returns: same shape as hidden_states - """ - _, batch, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - # assert not self.config.sequence_parallel - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, out_bias, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out, out_bias - - # (nheads_local) - A = -torch.exp(self.A_log.float()) - - # xz, _ = self.in_proj(hidden_states) # TransformerEngine also returns bias - xz = self.in_proj(hidden_states) - - # transpose: l b pd --> b l pd - xz = rearrange(xz, "l b d -> b l d").contiguous() - - if self.use_mem_eff_path and inference_params is None: - assert ssm_state is None - - if self.conv1d.bias is not None: - self.conv1d.bias.data_ptr() - - y = mamba_split_conv1d_scan_combined( - xz, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.dt_bias.float(), - A, - D=( - rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.D - ), - chunk_size=self.chunk_size, - activation=self.activation, - headdim=None if self.D_has_hdim else self.headdim, - ngroups=self.ngroups_local, - norm_before_gate=self.norm_before_gate, - ) - - if self.rmsnorm: - y = self.norm(y) - else: - z, xBC, dt = torch.split( - xz, - [ - self.d_inner_local, - self.d_inner_local + 2 * self.ngroups_local * self.d_state, - self.nheads_local, - ], - dim=-1, - ) - - # transpose: b l pd --> b pd l - xBC = rearrange(xBC, "b l d -> b d l").contiguous() - - # Compute short convolution - if conv_state is not None: - # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - conv_state.copy_( - F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) - ) # Update state (B D W) - - seqlen = xBC.size(2) - if causal_conv1d_fn is None: - xBC = self.act(self.conv1d(xBC)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - xBC = causal_conv1d_fn( - x=xBC, - weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, - activation=self.activation, - ) - - # transpose b pd l --> b l pd - xBC = rearrange(xBC, "b d l -> b l d").contiguous() - - x, B, C = torch.split( - xBC, - [ - self.d_inner_local, - self.ngroups_local * self.d_state, - self.ngroups_local * self.d_state, - ], - dim=-1, - ) - - # TO DO Vijay: fuse most of the transposes with the GEMMS - x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim).contiguous() - dt = dt.contiguous() - B = rearrange(B, "b l (g n) -> b l g n", n=self.d_state).contiguous() - C = rearrange(C, "b l (g n) -> b l g n", n=self.d_state).contiguous() - z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim).contiguous() - y = mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - self.chunk_size, - D=( - rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.D - ), - z=z if not self.rmsnorm else None, - dt_bias=self.dt_bias.float(), - dt_softplus=True, - return_final_states=ssm_state is not None, - ) - - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - - if self.rmsnorm: - y = rearrange(y, "b l h p -> b l (h p)").contiguous() - z = rearrange(z, "b l h p -> b l (h p)").contiguous() - y = self.norm(y, z) - else: - y = rearrange(y, "b l h p -> b l (h p)").contiguous() - - y = rearrange(y, "b l d -> l b d").contiguous() - # out, out_bias = self.out_proj(y) # TransformerEngine also returns bias - out = self.out_proj(y) - - return out - - def step(self, hidden_states, conv_state, ssm_state): - """ - Performs inference step for decoding - """ - # assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now" - dtype = hidden_states.dtype - assert hidden_states.shape[0] == 1, ( - "Only support decoding with 1 token at a time for now" - ) - - # l b d --> b d - hidden_states = hidden_states.squeeze(0) - - # b d_model --> b p(2d) - xz, _ = self.in_proj(hidden_states) - - z, xBC, dt = torch.split( - xz, - [ - self.d_inner_local, - self.d_inner_local + 2 * self.ngroups_local * self.d_state, - self.nheads_local, - ], - dim=-1, - ) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum( - conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 - ) # (B D) - if self.conv1d.bias is not None: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(dtype=dtype) - else: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x, B, C = torch.split( - xBC, - [ - self.d_inner_local, - self.ngroups_local * self.d_state, - self.ngroups_local * self.d_state, - ], - dim=-1, - ) - A = -torch.exp(self.A_log.float()) - - # SSM step - if selective_state_update is None: - if self.ngroups_local > 1: - B = rearrange(B, "b (g n) -> b g n", n=self.d_state) - C = rearrange(C, "b (g n) -> b g n", n=self.d_state) - B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) - C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) - - dt = repeat(dt, "b h -> b (h p)", p=self.headdim) - dt_bias = repeat(self.dt_bias, "h -> (h p)", p=self.headdim) - A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state) - D = repeat(self.D, "h -> (h p)", p=self.headdim) - - dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - - dB_x = torch.einsum("bd,bdn,bd->bdn", dt, B, x) - ssm_state.copy_( - ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim) - + rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim) - ) - - y = torch.einsum( - "bdn,bdn->bd", - rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim), - C, - ) - y = y + D.to(dtype) * x - if not self.rmsnorm: - y = y * self.act(z) # (B D) - else: - # Discretize A and B (b (g n)) - dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) - dA = torch.exp(dt * A) - x = rearrange(x, "b (h p) -> b h p", p=self.headdim) - dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) - ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) - y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) - y = y + rearrange(self.D.to(dtype), "h -> h 1") * x - y = rearrange(y, "b h p -> b (h p)") - if not self.rmsnorm: - y = y * self.act(z) # (B D) - else: - A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) - dt = repeat(dt, "b h -> b h p", p=self.headdim) - dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) - D = repeat(self.D, "h -> h p", p=self.headdim) - B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local) - C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local) - x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) - if not self.rmsnorm: - z = rearrange(z, "b (h p) -> b h p", p=self.headdim) - y = selective_state_update( - ssm_state, - x_reshaped, - dt, - A, - B, - C, - D, - z=z if not self.rmsnorm else None, - dt_bias=dt_bias, - dt_softplus=True, - ) - y = rearrange(y, "b h p -> b (h p)") - - if self.rmsnorm: - y = self.norm(y, z) - - # b pd --> b d - out, out_bias = self.out_proj(y) - return out.unsqueeze(0), out_bias, conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): - """ - allocate inference cache - """ - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.conv1d.weight.shape[0], - self.d_conv, - device=device, - dtype=conv_dtype, - ) - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, - self.nheads_local, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_number is not None - if self.layer_number not in inference_params.key_value_memory_dict: - conv_state = torch.zeros( - batch_size, - self.conv1d.weight.shape[0], - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.nheads_local, - self.headdim, - self.d_state, - device=self.in_proj.weight.device, - dtype=self.in_proj.weight.dtype, - ) - inference_params.key_value_memory_dict[self.layer_number] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_number] - # TO DO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - -except ImportError as exception: - mamba_error_message = f"Cannot declare MambaMixer due to missing dependencies: {exception=}." - warnings.warn(mamba_error_message) - - # TODO: Investigate why this type ignore is needed - class MambaMixerMegatron(nn.Module): # type: ignore[no-redef] - def __init__(self, *args, **kwargs): - raise ImportError(mamba_error_message) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py index 84496bc4a3..0102fc3a95 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py @@ -15,123 +15,19 @@ # Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved. # -# This code for Nvidia's model is based on the Llama modeling code by HuggingFace, -# which is in turn based on EleutherAI's GPT-NeoX library and the GPT-NeoX and -# OPT implementations in this library. -# Sliding window code based on Gemma2 by Google. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Pared-down DeciLM building blocks for Model-Optimizer puzzletron / AnyModel flows. +# The full HF DeciLM decoder stack (decoder layers, attention, rope, etc.) is not vendored here; +# AnyModel loads real models via transformers. This module keeps shared helpers: RMSNorm, +# gated/vanilla MLP (used by MoE accounting), MoE, and LMHead for replacement / validation code. # mypy: ignore-errors -import math - import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn -from transformers.utils import is_flash_attn_greater_or_equal_2_10, logging -from .block_config import AttentionConfig, FFNConfig, MambaConfig, MoEConfig +from .block_config import FFNConfig, MoEConfig from .configuration_decilm import DeciLMConfig -from .megatron_lm__mamba_mixer import MambaMixerMegatron from .transformers_4_44_2__activations import ACT2FN -from .transformers_4_44_2__cache_utils import Cache, StaticCache -from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter -from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import ( - _flash_attention_forward, -) -from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS -from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS -from .transformers_4_51_3__modeling_llama4_attention import Llama4TextAttention, Llama4TextConfig -from .variable_cache import VariableCache -from .vllm_yarn_utils import YaRNScalingRotaryEmbedding - -# from transformers.models.llama4.modeling_llama4 import Llama4TextL2Norm -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "DeciLMConfig" - - -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or - a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be - as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class Llama4TextL2Norm(torch.nn.Module): - def __init__(self, eps: float = 1e-6): - super().__init__() - self.eps = eps - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - return self._norm(x.float()).type_as(x) - - def extra_repr(self): - return f"eps={self.eps}" class DeciLMRMSNorm(nn.Module): @@ -154,349 +50,10 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm) - - -class DeciLMRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: DeciLMConfig | None = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.45" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type") - ) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_impl = "rope" if config is None else config.position_embedding_type - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - def _set_inv_freq_if_needed(self, device: torch.device) -> None: - is_missing_inv_freq = not hasattr(self, "inv_freq") - is_meta_mismatch = not is_missing_inv_freq and ( - str(device) != "meta" and self.inv_freq.is_meta - ) - - if is_missing_inv_freq or is_meta_mismatch: - with torch.device(device): - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, **self.rope_kwargs - ) - self.original_inv_freq = inv_freq - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer( - "inv_freq", inv_freq, persistent=False - ) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if ( - seq_len < self.original_max_seq_len - and self.max_seq_len_cached > self.original_max_seq_len - ): # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - self._set_inv_freq_if_needed(x.device) - - if self.rope_impl == "rope_llama4": - return self.llama4_forward(x, position_ids) - else: - return self.llama3_forward(x, position_ids) - - def llama3_forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def llama4_forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - # Core RoPE block - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) - freqs_cis = torch.polar( - torch.ones_like(freqs), freqs - ) # Convert to complex representation - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - freqs_cis = freqs_cis * self.attention_scaling - return freqs_cis - - -class DeciMistralYarnRotaryEmbedding(nn.Module): - def __init__(self, config: DeciLMConfig): - super().__init__() - self.config = config - self.rope_scaling = config.rope_scaling - self.base = config.rope_theta - self.rope_impl = config.position_embedding_type - self.head_size = config.hidden_size // config.num_attention_heads - self.yarn = YaRNScalingRotaryEmbedding( - head_size=self.head_size, - rotary_dim=self.head_size, - max_position_embeddings=self.rope_scaling["original_max_position_embeddings"], - base=self.base, - is_neox_style=True, - scaling_factor=self.rope_scaling["factor"], - beta_fast=self.rope_scaling["beta_fast"], - beta_slow=self.rope_scaling["beta_slow"], - dtype=torch.float32, - ) - self.attention_scaling = self.yarn.mscale - self.scaling_factor = self.rope_scaling["factor"] - self.rope_impl = "rope" if config is None else config.position_embedding_type - self.rope_impl = "even_odd" - - def _set_inv_freq_if_needed(self, device: torch.device) -> None: - is_missing_inv_freq = not hasattr(self, "inv_freq") - is_meta_mismatch = not is_missing_inv_freq and ( - str(device) != "meta" and self.inv_freq.is_meta - ) - - if is_missing_inv_freq or is_meta_mismatch: - with torch.device(device): - inv_freq = self.yarn._compute_inv_freq(self.scaling_factor) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def halves_forward(self, x, position_ids): - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - - self._set_inv_freq_if_needed(x.device) - - # print(f"halves_forward") - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - inv_freq_expanded = inv_freq_expanded.to(x.device) - # print(f"inv_freq_expanded: {inv_freq_expanded.device}") - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def forward(self, x, position_ids): - if self.rope_impl == "halves": - return self.halves_forward(x, position_ids) - elif self.rope_impl == "even_odd": - return self.even_odd_forward(x, position_ids) - else: - raise ValueError(f"Invalid rope implementation: {self.rope_impl}") - - def even_odd_forward(self, x, position_ids): - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - - self._set_inv_freq_if_needed(x.device) - - # print(f"even_odd_forward") - # Core RoPE block - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) - freqs_cis = torch.polar( - torch.ones_like(freqs), freqs - ) # Convert to complex representation - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - freqs_cis = freqs_cis * self.attention_scaling - return freqs_cis - - -class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding): - """DeciLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`DeciLMLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." - ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding): - """DeciLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - -rope_type_to_class = { - "default": DeciLMRotaryEmbedding, - "linear": DeciLMLinearScalingRotaryEmbedding, - "dynamic": DeciLMDynamicNTKScalingRotaryEmbedding, - "rope_llama4": DeciLMRotaryEmbedding, - "rope": DeciLMRotaryEmbedding, - "mistral_yarn": DeciMistralYarnRotaryEmbedding, -} - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, freqs_cis, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - freqs_cis (`torch.Tensor`): The frequency tensor. - a tuple of two tensors, cos and sin. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - # print(f"applying first half-second half") - cos, sin = freqs_cis - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def vllm_apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - # print(f"freqs_cis: {freqs_cis.shape}, xq_: {xq_.shape}, xk_: {xk_.shape}") - xq_out = torch.view_as_real(xq_ * freqs_cis[:, None, :, :]).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis[:, None, :, :]).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) +def sparsity_backward_hook(*args, **kwargs): + raise NotImplementedError( + "No support for sparsity when training HF DeciLM (inference is ok though)" + ) class DeciLMGatedMLP(nn.Module): @@ -545,1040 +102,6 @@ def forward(self, x): return down_proj -class DeciLMVanillaMLP(nn.Module): - def __init__( - self, - config: DeciLMConfig, - ffn_config: FFNConfig, - ): - super().__init__() - self.config = config - self.ffn_config = ffn_config - self.hidden_size = config.hidden_size - self.intermediate_size = ffn_config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[getattr(ffn_config, "hidden_act", "silu")] - - if ffn_config.sparsify is not None: - self.register_full_backward_hook(sparsity_backward_hook) - - assert self.config.pretraining_tp == 1, ( - "Unsupported pretraining_tp != 1 for DeciLMVanillaMLP" - ) - - def forward(self, x): - return self.down_proj(self.act_fn(self.up_proj(x))) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class DeciLMAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: DeciLMConfig, - attention_config: AttentionConfig, - layer_idx: int | None = None, - ): - super().__init__() - self.config = config - self.attention_config = attention_config # type: AttentionConfig - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - if config.head_dim is not None: - self.head_dim = config.head_dim - else: - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_groups = attention_config.n_heads_in_group # DeciLM-specific code - self.num_key_value_heads = ( - self.num_heads // self.num_key_value_groups - ) # DeciLM-specific code - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - # llama4 attention specific - self.llama4_attn_config = attention_config.llama4 - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=config.o_proj_bias - ) - - if self.config.position_embedding_type in ["rope", "rope_llama4", "mistral_yarn"]: - # TO DO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = rope_type_to_class[self.config.position_embedding_type]( - config=self.config - ) - - if attention_config.sparsify is not None: - self.register_full_backward_hook(sparsity_backward_hook) - - self.is_llama4 = self.llama4_attn_config is not None - if ( - self.is_llama4 - and self.llama4_attn_config.use_qk_norm - and self.llama4_attn_config.use_rope - ): - self.qk_norm = Llama4TextL2Norm(self.config.rms_norm_eps) - - self.use_rope = ( - self.llama4_attn_config.use_rope - if self.is_llama4 - else self.config.position_embedding_type in ["rope", "mistral_yarn"] - ) - self.rope_impl = self.rotary_emb.rope_impl - self.apply_rope_fn = ( - apply_rotary_emb - if self.rope_impl in ["even_odd", "rope_llama4"] - else apply_rotary_pos_emb - ) - # self.apply_rope_fn = apply_rotary_emb - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] - | None = None, # will become mandatory in v4.45 - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - bsz, q_len, _ = hidden_states.size() - input_shape = hidden_states.shape[:-1] - - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( - 1, 2 - ) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - if self.use_rope: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE " - "embeddings internally through `position_ids` (2D tensor with the indexes of the " - "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " - "containing cos and sin). In v4.45 `position_ids` will be removed and " - "`position_embeddings` will be mandatory." - ) - freqs_cis = self.rotary_emb(value_states, position_ids) - else: - freqs_cis = position_embeddings - - query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) - - if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm - query_states = self.qk_norm(query_states) - key_states = self.qk_norm(key_states) - - if self.is_llama4: - query_states = self.apply_attention_scaling(input_shape, cache_position, query_states) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - # print(f"cache_position: {cache_position}") - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1 - ) - attn_output = sum( - [ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ] - ) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def apply_attention_scaling(self, input_shape, cache_position, query_states): - # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers - if self.llama4_attn_config.attn_temperature_tuning and not self.use_rope: - attn_scales = ( - torch.log( - torch.floor( - (cache_position.float() + 1.0) / self.llama4_attn_config.floor_scale - ) - + 1.0 - ) - * self.llama4_attn_config.attn_scale - + 1.0 - ) - attn_scales = attn_scales.view((*input_shape, 1, 1)).transpose(1, 2) - query_states = (query_states * attn_scales).to(query_states.dtype) - return query_states - return query_states - - -class DeciLMFlashAttention2(DeciLMAttention): - """ - DeciLM flash attention module. This module inherits from `DeciLMAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is - # bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is - # used to handle this difference. - # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case - # q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - self.sliding_window = self.attention_config.prefill_sliding_window - - self.pre_attention_identity_query = nn.Identity() # for debugging hooks - self.pre_attention_identity_key = nn.Identity() # for debugging hooks - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] - | None = None, # will become mandatory in v4.45 - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( - 1, 2 - ) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - if self.config.position_embedding_type in ["rope", "mistral_yarn"]: - # llama4 doesn't use flash attention - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE " - "embeddings internally through `position_ids` (2D tensor with the indexes of the " - "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " - "containing cos and sin). In v4.45 `position_ids` will be removed and " - "`position_embeddings` will be mandatory." - ) - freqs_cis = self.rotary_emb(value_states, position_ids) - else: - freqs_cis = position_embeddings - - query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freq_cis) - # print(f"applying even odd rope") - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV - # cache to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (DeciLMRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - query_states = self.pre_attention_identity_query(query_states) - key_states = self.pre_attention_identity_key(key_states) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -DECILM_ATTENTION_CLASSES = { - "eager": DeciLMAttention, - "flash_attention_2": DeciLMFlashAttention2, -} - - -class DeciLMLlama4TextAttention(Llama4TextAttention): - def __init__(self, config: DeciLMConfig, layer_idx: int, attention_config: AttentionConfig): - llama4_text_config = Llama4TextConfig( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_attention_heads // attention_config.n_heads_in_group, - head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), - attn_scale=attention_config.llama4.attn_scale, - floor_scale=attention_config.llama4.floor_scale, - attn_temperature_tuning=attention_config.llama4.attn_temperature_tuning, - attention_dropout=attention_config.llama4.attention_dropout, - use_qk_norm=attention_config.llama4.use_qk_norm, - use_rope=attention_config.llama4.use_rope, - rms_norm_eps=config.rms_norm_eps, - attention_bias=config.attention_bias, - attn_implementation=config.llama4_attn_implementation, - rope_scaling=config.rope_scaling, - max_position_embeddings=config.max_position_embeddings, - attention_chunk_size=attention_config.llama4.attention_chunk_size, - ) - super().__init__(llama4_text_config, layer_idx, use_rope=attention_config.llama4.use_rope) - - -class DeciLMDecoderLayer(nn.Module): - # DeciLM-specific code - def __init__(self, config: DeciLMConfig, layer_idx: int | tuple[int, ...]): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.block_config = config.get_block_config(layer_idx) - - self.attention_config = self.block_config.attention - self.ffn_config = self.block_config.ffn - self.layer_idx = layer_idx - - if not config._attn_implementation: - config._attn_implementation = "eager" - - if not self.attention_config.no_op: - self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if self.attention_config.replace_with_linear: - self.self_attn = DeciLMLinearAttention(config) - elif self.attention_config.is_mamba: - self.self_attn = DeciLMMambaMixer(config, self.attention_config.mamba) - elif not self.attention_config.is_llama4: - self.self_attn = DECILM_ATTENTION_CLASSES[config._attn_implementation]( - config=config, attention_config=self.attention_config, layer_idx=layer_idx - ) - else: - self.self_attn = DeciLMLlama4TextAttention(config, layer_idx, self.attention_config) - - if not (self.ffn_config.no_op or self.attention_config.is_mamba): - if getattr(self.ffn_config, "hidden_act", None) is None: - print(f"WARNING: FFN hidden_act is None for layer {layer_idx}") - - self.post_attention_layernorm = DeciLMRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - if self.ffn_config.replace_with_linear: - self.mlp = DeciLMLinearMLP(config) - elif self.ffn_config.is_moe: - self.mlp = DeciLMMoe(config, self.ffn_config) - else: - self.mlp = ( - DeciLMGatedMLP(config, self.ffn_config) - if self.ffn_config.gated - else DeciLMVanillaMLP(config, self.ffn_config) - ) - - self.is_sliding = self.attention_config.is_sliding - self.sliding_window = self.attention_config.prefill_sliding_window - self.return_only_hidden_states = self.config.block_return_only_hidden_states - - @property - def device(self): - try: - return next(self.parameters()).device - except StopIteration: - return None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool | None = False, - output_router_logits: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] - | None = None, # necessary, but kept here for BC - **kwargs, - ) -> ( - tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None] - | torch.FloatTensor - ): - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - paramz = list(self.parameters()) - device = paramz[0].device if len(paramz) > 0 else None - if isinstance(hidden_states, tuple): - # could happen when sewing kit sends the output of the previous layer - # to this layer without going through the model forward unpacking code. - # can be avoided by using config.block_return_only_hidden_states=True - hidden_states = hidden_states[0] - - hidden_states = hidden_states.to(device) - - if cache_position is not None: - cache_position = cache_position.to(device) - - if self.attention_config.llama4 is not None: - # chunk_size = self.attention_config.llama4.attention_chunk_size - # print(f"pre-llama4_update: {attention_mask=}") - # causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( - # attention_mask, hidden_states, cache_position, past_key_value, output_attentions, use_cache=use_cache, - # ) - # attention_mask = causal_mask if (chunk_size is None) else chunk_causal_mask - # if (past_key_value is not None) and isinstance(attention_mask, BlockMask): - # print(f"pre-adjust: {attention_mask.shape=}") - # print(f"pre-adjust: {hidden_states.shape=}") - # print(f"pre-adjust: {past_key_value.get_seq_length()=}") - # q_len = hidden_states.shape[1] - # kv_len = past_key_value.get_seq_length() - # if kv_len == 0: - # kv_len = q_len - # print(f"pre-adjust: {kv_len=} {q_len=}") - # print(f"post-adjust: {attention_mask.shape=}") - assert self.config.llama4_attn_implementation != "flex_attention", ( - "We have a mask issue with flex attention" - ) - - causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( - attention_mask, - hidden_states, - cache_position, - past_key_value, - output_attentions, - use_cache=use_cache, - ) - is_chunked = self.attention_config.llama4.attention_chunk_size is not None - attention_mask = ( - chunk_causal_mask if is_chunked and (chunk_causal_mask is not None) else causal_mask - ) - - else: - attention_mask = self._llama3_update_causal_mask( - attention_mask, hidden_states, cache_position, past_key_value, output_attentions - ) - if self.attention_config.unshifted_sink and self.attention_config.is_sink: - attention_mask = self._unshifted_sink_mask( - attention_mask, - hidden_states, - self.attention_config.window_length, - self.attention_config.num_sink_tokens, - ) - else: - attention_mask = self._gemma2_window_mask( - attention_mask, hidden_states, past_key_value - ) - - self_attn_weights = None - present_key_value = past_key_value - router_logits = None - - if self.attention_config.no_op: - pass - elif self.attention_config.replace_with_linear or self.attention_config.is_mamba: - if self.attention_config.is_mamba: - assert past_key_value is None, "DeciLM does not support generation with Mamba yet" - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(hidden_states) - hidden_states = residual + hidden_states - else: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - attn_out = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states, self_attn_weights = attn_out[:2] - if len(attn_out) > 2: - present_key_value = attn_out[2] - - hidden_states = residual + hidden_states - - if not self.ffn_config.no_op: - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # Handle MoE layers differently as they return router logits - if self.ffn_config.is_moe: - hidden_states, router_logits = self.mlp(hidden_states) - else: - hidden_states = self.mlp(hidden_states) - - hidden_states = residual + hidden_states - - if self.return_only_hidden_states: - return hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if output_router_logits and router_logits is not None: - outputs += (router_logits,) - - return outputs - - def _gemma2_window_mask( - self, - attention_mask: torch.Tensor | None, - hidden_states: torch.Tensor, - past_key_value: VariableCache | None, - ) -> torch.Tensor | None: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # Flash-attn is a 2D tensor - if self.config._attn_implementation == "flash_attention_2": - if past_key_value is not None: # when decoding - attention_mask = attention_mask[:, -self.sliding_window :] - else: - min_dtype = torch.finfo(hidden_states.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - if attention_mask.shape[-1] <= 1: # when decoding - attention_mask = attention_mask[:, :, :, -self.sliding_window :] - return attention_mask - - def _unshifted_sink_mask( - self, - attention_mask: torch.Tensor, - hidden_states: torch.Tensor, - window_length: int, - num_sink_tokens: int | None, - ) -> torch.Tensor: - assert self.config._attn_implementation == "eager", ( - "Unshifted sink is only supported in 'eager' mode." - ) - assert attention_mask is not None, "The attention mask seems to not be prepared" - - attention_mask = attention_mask.clone() - min_dtype = torch.finfo(hidden_states.dtype).min - - if window_length == 0: - attention_mask = torch.full_like(attention_mask, fill_value=min_dtype) - else: - query_length = attention_mask.shape[-2] - is_decode = query_length == 1 - if is_decode: - attention_mask[:, :, :, :-window_length] = min_dtype - else: - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-window_length - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - - attention_mask[:, :, :, :num_sink_tokens] = 0 - return attention_mask - - def _llama3_update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is - # 2D and of dynamic length even when the static KV cache is used. This is an issue for - # torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic - # shapes. (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. - # A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`. - # See more context in https://github.com/huggingface/transformers/pull/29114 - - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache" - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not using_static_cache - and not output_attentions - ): - if ( - AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ) - and not self.is_sliding - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_length() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @torch.compiler.disable(recursive=False) # the operations in this method are not compilable - def _llama4_update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache | None, - output_attentions: bool = False, - chunked_attention_mask=None, - use_cache=True, - ): - attn_implementation = self.config.llama4_attn_implementation - - if attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return ( - attention_mask, - attention_mask, - ) # flash does not support chunked attn TODO support flash - return None, None - - if attn_implementation not in ["sdpa", "flex_attention", "eager"]: - return None, None - - sequence_length = input_tensor.shape[1] - cache_position = cache_position.to(self.device) - attention_chunk_size = self.attention_config.llama4.attention_chunk_size - if attention_chunk_size is None: - # let the function build some chunked mask, we won't use it since it's not a chunked - # attention layer. We still need to know the chunk size for this if statement that - # comes later on: if attn_implementation == "sdpa" and chunked_attention_mask is not None - # otherwise the mask dtype is wrong for sdpa :bufo-wat: - attention_chunk_size = self.config.get_min_attention_chunk_size() - if attention_chunk_size is None: - logger.warning_once( - "Could not infer attention_chunk_size since the model (or the model shard) " - "has no chunked attention, using 8192 as default for mask construction" - ) - attention_chunk_size = 8192 - - first_cache_position = cache_position[0] - - if past_key_values is not None: - full_cache_length = past_key_values.get_max_cache_shape() or sequence_length - else: - full_cache_length = ( - attention_mask.shape[-1] if attention_mask is not None else sequence_length - ) - - cond1 = first_cache_position >= attention_chunk_size - cond2 = (first_cache_position < attention_chunk_size) & ( - first_cache_position + sequence_length > attention_chunk_size - ) - key_length = ( - torch.where( - cond1, - attention_chunk_size + sequence_length - 1, - torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), - ) - if use_cache - else full_cache_length - ) - - if attn_implementation == "flex_attention": - raise NotImplementedError("DeciLM Llama4 does not support flex attention") - # if isinstance(attention_mask, torch.Tensor): - # offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0)) - # chunked_attention_mask = make_flex_block_causal_mask( - # attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets - # ) - # attention_mask = make_flex_block_causal_mask( - # attention_mask, - # query_length=sequence_length, - # key_length=full_cache_length, - # offsets=(first_cache_position, 0), - # ) - # return attention_mask, chunked_attention_mask - # if isinstance(attention_mask, BlockMask): - # return attention_mask, chunked_attention_mask - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - dtype, device = input_tensor.dtype, input_tensor.device - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=max(full_cache_length, attention_chunk_size), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - min_dtype=torch.finfo(dtype).min, - ) - if full_cache_length > attention_chunk_size: - start_idx = max(first_cache_position - attention_chunk_size + 1, 0) - end_idx = start_idx + key_length - chunked_attention_mask = self.create_chunked_attention_mask( - attention_chunk_size, - start=start_idx, # same offset as with flex - end=end_idx, - device=device, - ) - - ### Deci: we added this code to patch a bug in transformers - if attention_mask is None: - if past_key_values is not None: - raise NotImplementedError("We only support attention_mask=None is prefill") - attention_mask = torch.ones( - input_tensor.shape[0], input_tensor.shape[1], device=device, dtype=torch.long - ) - - local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well - # It may be smaller than attention_chunk_size -> pad it - requires_padding = local_attention_mask.shape[-1] < attention_chunk_size - if requires_padding: - local_attention_mask = nn.functional.pad( - local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) - ) - # Depending on the padding, take the query tokens from the end or the cache_position - if not requires_padding: - chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] - else: - chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] - - chunked_attention_mask = chunked_attention_mask.expand( - input_tensor.shape[0], -1, -1, -1 - ) - chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] - if attn_implementation == "eager": - min_dtype = torch.finfo(dtype).min - chunked_attention_mask = torch.where( - chunked_attention_mask == 0, min_dtype, 0.0 - ).to(dtype) - - # print(f"{output_attentions=}") - - if ( - attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and attention_mask.ndim == 4 - and not output_attentions # Only unmask for 4d masks - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if attn_implementation == "sdpa" and chunked_attention_mask is not None: - chunked_attention_mask = chunked_attention_mask.bool() - causal_mask = causal_mask.bool() - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=first_cache_position, - is_training=self.training, - ): - causal_mask = None - return causal_mask, chunked_attention_mask - - def create_chunked_attention_mask( - self, attention_chunk_size: int, start: int, end: int, device: torch.device - ) -> torch.Tensor: - """ - Generate the following: - - 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | - '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | - '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | - 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | - '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | - '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ | - - If the chunk size is 3. - This can just be appplied over the already created attention mask - """ - arange_vector = torch.arange(start, end, device=device) - block_pos = torch.abs( - arange_vector.unsqueeze(0) // attention_chunk_size - - arange_vector.unsqueeze(1) // attention_chunk_size - ) - token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) - mask = (block_pos == 0) & (token_pos <= 0) - return mask.to(device) - - -class DeciLMMultiDecoderLayer(nn.Module): - def __init__(self, config: DeciLMConfig, layer_idx: int): - super().__init__() - self.config = config - block_config = config.block_configs[layer_idx] - assert block_config.parallel_blocks is not None - num_parallel_blocks = len(block_config.parallel_blocks) - self.parallel_blocks = nn.ModuleList( - [ - DeciLMDecoderLayer(config, (layer_idx, internal_block_idx)) - for internal_block_idx in range(num_parallel_blocks) - ] - ) - - def forward( - self, - hidden_states: torch.Tensor, - *args, - **kwargs, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - block_outputs = [block(hidden_states, *args, **kwargs) for block in self.parallel_blocks] - output_hidden_states = [ - out[0].to(hidden_states.device) - if isinstance(out, tuple) - else out.to(hidden_states.device) - for out in block_outputs - ] - output_hidden_states = torch.stack(output_hidden_states, dim=0).sum(dim=0) - output_hidden_states = ( - output_hidden_states - (len(self.parallel_blocks) - 1) * hidden_states - ) - - if self.config.block_return_only_hidden_states: - return output_hidden_states - - other_outputs = block_outputs[0][1:] - outputs = (output_hidden_states, *other_outputs) - return outputs - - -######################################################################## -# DeciLM-specific code -######################################################################## - - -def _find_multiple(n: int, k: int) -> int: - # DeciLM-specific code - if n % k == 0: - return n - return n + k - (n % k) - - class DeciLMMoe(nn.Module): """ Implementation of Mixture of Experts module for DeciLM. @@ -1680,64 +203,6 @@ def extra_repr(self) -> str: ) -class DeciLMLinearMLP(nn.Module): - # DeciLM-specific code - def __init__( - self, - config: DeciLMConfig, - ): - super().__init__() - self.linear_mlp = nn.Linear( - in_features=config.hidden_size, out_features=config.hidden_size, bias=False - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear_mlp.forward(x) - - -class DeciLMLinearAttention(nn.Module): - # DeciLM-specific code - def __init__( - self, - config: DeciLMConfig, - ): - super().__init__() - self.linear_attn = nn.Linear( - in_features=config.hidden_size, out_features=config.hidden_size, bias=False - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear_attn.forward(x) - - -def sparsity_backward_hook(*args, **kwargs): - raise NotImplementedError( - "No support for sparsity when training HF DeciLM (inference is ok though)" - ) - - -class DeciLMMambaMixer(nn.Module): - def __init__( - self, - config: DeciLMConfig, - mamba_config: MambaConfig, - ): - super().__init__() - self.mamba_mixer = MambaMixerMegatron( - d_model=config.hidden_size, - d_state=mamba_config.state_dim, - nheads=mamba_config.num_heads, - headdim=mamba_config.head_dim, - ngroups=mamba_config.num_groups, - ) - - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - x = x.permute([1, 0, 2]) # MambaMixerMegatron expects [Sequence, Batch, Embedding] - out = self.mamba_mixer(x) - out = out.permute([1, 0, 2]) # go back to [Batch, Sequence, Embedding] - return out - - class LMHead(nn.Linear): """ Special class to allow FSDP wrapping without affecting other Linear layers in the model. diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py deleted file mode 100644 index 7257800678..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py +++ /dev/null @@ -1,498 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import torch - - -@dataclass -class AttentionMaskConverter: - """ - A utility attention mask class that allows one to: - - Create a causal 4d mask - - Create a causal 4d mask with slided window - - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, - key_value_length) that can be multiplied with attention scores - - Examples: - - ```python - >>> import torch - >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter - - >>> converter = AttentionMaskConverter(True) - >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) - tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) - ``` - - Parameters: - is_causal (`bool`): - Whether the attention mask should be a uni-directional (causal) or bi-directional mask. - - sliding_window (`int`, *optional*): - Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. - """ - - is_causal: bool - sliding_window: int - - def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): - self.is_causal = is_causal - self.sliding_window = sliding_window - - if self.sliding_window is not None and self.sliding_window <= 0: - raise ValueError( - f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" - ) - - def to_causal_4d( - self, - batch_size: int, - query_length: int, - key_value_length: int, - dtype: torch.dtype, - device: Union[torch.device, "str"] = "cpu", - ) -> Optional[torch.Tensor]: - """ - Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative - bias to upper right hand triangular matrix (causal mask). - """ - if not self.is_causal: - raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") - - # If shape is not cached, create a new causal mask and cache it - input_shape = (batch_size, query_length) - past_key_values_length = key_value_length - query_length - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if input_shape[-1] > 1 or self.sliding_window is not None: - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - - return causal_4d_mask - - def to_4d( - self, - attention_mask_2d: torch.Tensor, - query_length: int, - dtype: torch.dtype, - key_value_length: Optional[int] = None, - ) -> torch.Tensor: - """ - Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, - key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is - causal, a causal mask will be added. - """ - input_shape = (attention_mask_2d.shape[0], query_length) - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: - if key_value_length is None: - raise ValueError( - "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." - ) - - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=attention_mask_2d.device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - elif self.sliding_window is not None: - raise NotImplementedError("Sliding window is currently only implemented for causal masking") - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( - attention_mask_2d.device - ) - - if causal_4d_mask is not None: - expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) - - # expanded_attn_mask + causal_4d_mask can cause some overflow - expanded_4d_mask = expanded_attn_mask - - return expanded_4d_mask - - @staticmethod - def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: Optional[int] = None, - ): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - - # add lower triangular sliding window mask if necessary - if sliding_window is not None: - diagonal = past_key_values_length - sliding_window - 1 - - context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) - mask.masked_fill_(context_mask, torch.finfo(dtype).min) - - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - @staticmethod - def _unmask_unattended( - expanded_mask: torch.FloatTensor, - min_dtype: float, - ): - # fmt: off - """ - Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when - using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - Details: https://github.com/pytorch/pytorch/issues/110213 - - `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. - `attention_mask` is [bsz, src_seq_len]. - - The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. - - For example, if `expanded_mask` is (e.g. here left-padding case) - ``` - [[[[0, 0, 0], - [0, 0, 0], - [0, 0, 1]]], - [[[1, 0, 0], - [1, 1, 0], - [1, 1, 1]]], - [[[0, 0, 0], - [0, 1, 0], - [0, 1, 1]]]] - ``` - then the modified `expanded_mask` will be - ``` - [[[[1, 1, 1], <-- modified - [1, 1, 1], <-- modified - [0, 0, 1]]], - [[[1, 0, 0], - [1, 1, 0], - [1, 1, 1]]], - [[[1, 1, 1], <-- modified - [0, 1, 0], - [0, 1, 1]]]] - ``` - """ - # fmt: on - if expanded_mask.dtype == torch.bool: - raise ValueError( - "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." - ) - - return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) - - @staticmethod - def _ignore_causal_mask_sdpa( - attention_mask: Optional[torch.Tensor], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, - is_training: bool = False, - ) -> bool: - """ - Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. - - In case no token is masked in the `attention_mask` argument, if `query_length == 1` or - `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, - allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). - """ - - _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] - key_value_length = query_length + past_key_values_length - - is_tracing = ( - torch.jit.is_tracing() - or isinstance(inputs_embeds, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - - ignore_causal_mask = False - - if attention_mask is None: - # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or - # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). - # Thus, we only set `ignore_causal_mask = True` if the model is set to training. - # - # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). - if ( - (is_training or not is_tracing) - and (query_length == 1 or key_value_length == query_length) - and (sliding_window is None or key_value_length < sliding_window) - ): - ignore_causal_mask = True - elif sliding_window is None or key_value_length < sliding_window: - if len(attention_mask.shape) == 4: - return False - elif (is_training or not is_tracing) and torch.all(attention_mask == 1): - if query_length == 1 or key_value_length == query_length: - # For query_length == 1, causal attention and bi-directional attention are the same. - ignore_causal_mask = True - - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 - # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. - - return ignore_causal_mask - - -def _prepare_4d_causal_attention_mask( - attention_mask: Optional[torch.Tensor], - input_shape: Union[torch.Size, Tuple, List], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)` - - Args: - attention_mask (`torch.Tensor` or `None`): - A 2D attention mask of shape `(batch_size, key_value_length)` - input_shape (`tuple(int)` or `list(int)` or `torch.Size`): - The input shape should be a tuple that defines `(batch_size, query_length)`. - inputs_embeds (`torch.Tensor`): - The embedded inputs as a torch Tensor. - past_key_values_length (`int`): - The length of the key value cache. - sliding_window (`int`, *optional*): - If the model uses windowed attention, a sliding window should be passed. - """ - attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = input_shape[-1] + past_key_values_length - - # 4d mask is passed through the layers - if attention_mask is not None and len(attention_mask.shape) == 2: - attention_mask = attn_mask_converter.to_4d( - attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype - ) - elif attention_mask is not None and len(attention_mask.shape) == 4: - expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) - else: - # if the 4D mask has correct shape - invert it and fill with negative infinity - inverted_mask = 1.0 - attention_mask - attention_mask = inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min - ) - else: - attention_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - - return attention_mask - - -# Adapted from _prepare_4d_causal_attention_mask -def _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask: Optional[torch.Tensor], - input_shape: Union[torch.Size, Tuple, List], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, -): - """ - Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. - - In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and - `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, - allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). - """ - attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = input_shape[-1] + past_key_values_length - - # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` - # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. - # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - is_tracing = ( - torch.jit.is_tracing() - or isinstance(inputs_embeds, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - - ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - - if ignore_causal_mask: - expanded_4d_mask = None - elif attention_mask is None: - expanded_4d_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - else: - if attention_mask.dim() == 4: - # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing - if attention_mask.max() != 0: - raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") - expanded_4d_mask = attention_mask - else: - expanded_4d_mask = attn_mask_converter.to_4d( - attention_mask, - input_shape[-1], - dtype=inputs_embeds.dtype, - key_value_length=key_value_length, - ) - - # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - if not is_tracing and expanded_4d_mask.device.type == "cuda": - expanded_4d_mask = AttentionMaskConverter._unmask_unattended( - expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min - ) - - return expanded_4d_mask - - -def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)` - - Args: - mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` - dtype (`torch.dtype`): - The torch dtype the created mask shall have. - tgt_len (`int`): - The target length or query length the created mask shall have. - """ - return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) - - -def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)` - - Args: - mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` - dtype (`torch.dtype`): - The torch dtype the created mask shall have. - tgt_len (`int`): - The target length or query length the created mask shall have. - """ - _, key_value_length = mask.shape - tgt_len = tgt_len if tgt_len is not None else key_value_length - - is_tracing = ( - torch.jit.is_tracing() - or isinstance(mask, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - - # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. - if not is_tracing and torch.all(mask == 1): - return None - else: - return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) - - -def _create_4d_causal_attention_mask( - input_shape: Union[torch.Size, Tuple, List], - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: Optional[int] = None, -) -> Optional[torch.Tensor]: - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` - - Args: - input_shape (`tuple(int)` or `list(int)` or `torch.Size`): - The input shape should be a tuple that defines `(batch_size, query_length)`. - dtype (`torch.dtype`): - The torch dtype the created mask shall have. - device (`int`): - The torch device the created mask shall have. - sliding_window (`int`, *optional*): - If the model uses windowed attention, a sliding window should be passed. - """ - attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = past_key_values_length + input_shape[-1] - attention_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device - ) - - return attention_mask diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py deleted file mode 100644 index 9e9fb46ca4..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py +++ /dev/null @@ -1,363 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -import inspect -import os -from typing import Optional, Tuple, Union - - -import torch -import torch.nn.functional as F - -from functools import lru_cache -import importlib.metadata -import importlib.util -from packaging import version - -from transformers.utils import is_flash_attn_2_available - - -if is_flash_attn_2_available(): - try: - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - from flash_attn import flash_attn_func, flash_attn_varlen_func - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - except ImportError: - raise "Unable to import flash_attn" - - -def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: - # Check if the package spec exists and grab its version to avoid importing a local directory - package_exists = importlib.util.find_spec(pkg_name) is not None - package_version = "N/A" - if package_exists: - try: - # Primary method to get the package version - package_version = importlib.metadata.version(pkg_name) - except importlib.metadata.PackageNotFoundError: - # Fallback method: Only for "torch" and versions containing "dev" - if pkg_name == "torch": - try: - package = importlib.import_module(pkg_name) - temp_version = getattr(package, "__version__", "N/A") - # Check if the version contains "dev" - if "dev" in temp_version: - package_version = temp_version - package_exists = True - else: - package_exists = False - except ImportError: - # If the package can't be imported, it's not available - package_exists = False - else: - # For packages other than "torch", don't attempt the fallback and set as not available - package_exists = False - if return_version: - return package_exists, package_version - else: - return package_exists - - -@lru_cache() -def is_flash_attn_greater_or_equal(library_version: str): - if not _is_package_available("flash_attn"): - return False - - return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) - - -def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - max_seqlen_in_batch (`int`): - Maximum sequence length in batch. - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _upad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, -): - """ - Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary - tensors for query, key, value tensors. - - Arguments: - query_layer (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - query_length (`int`): - Target length. - - Return: - query_layer (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -def prepare_fa2_from_position_ids(query, key, value, position_ids): - """ - This function returns necessary arguments to call `flash_attn_varlen_func`. - All three query, key, value states will be flattened. - Cummulative lengths of each examples in the batch will be extracted from position_ids. - - NOTE: ideally cummulative lengths should be prepared at the data collator stage - - Arguments: - query (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - position_ids (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - query (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - position_ids = position_ids.flatten() - indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) - - cu_seq_lens = torch.cat( - ( - indices_q[position_ids == 0], - torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), - ) - ) - - max_length = position_ids.max() + 1 - - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) - - -def _flash_attention_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - is_causal: bool, - dropout: float = 0.0, - position_ids: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - sliding_window: Optional[int] = None, - use_top_left_mask: bool = False, - softcap: Optional[float] = None, - deterministic: bool = None, -): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_top_left_mask (`bool`, defaults to `False`): - flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. - softcap (`float`, *optional*): - Softcap for the attention logits, used e.g. in gemma2. - deterministic (`bool`, *optional*): - Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. - """ - if not use_top_left_mask: - causal = is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. - causal = is_causal and query_length != 1 - - # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). - use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window - ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - - if is_flash_attn_greater_or_equal("2.4.1"): - if deterministic is None: - deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - flash_kwargs["deterministic"] = deterministic - - if softcap is not None: - flash_kwargs["softcap"] = softcap - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - - # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing - # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. - # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): - batch_size = query_states.size(0) - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) - - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs - ) - - return attn_output diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py deleted file mode 100644 index aa9f07b879..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py +++ /dev/null @@ -1,1768 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch - -from transformers.utils import ModelOutput - - -@dataclass -class BaseModelOutput(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithNoAttention(ModelOutput): - """ - Base class for model's outputs, with potential hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPooling(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPoolingAndNoAttention(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state after a pooling operation on the spatial dimensions. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPast(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithCrossAttentions(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class MoECausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden - states terms, to train a MoE model. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - z_loss for the sparse modules. - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - aux_loss for the sparse modules. - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse - modules. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - z_loss: torch.FloatTensor = None - aux_loss: torch.FloatTensor = None - router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoEModelOutput(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary - loss and the z_loss for Mixture of Experts models. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_probs: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoeModelOutputWithPast(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary - loss for Mixture of Experts models. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoeCausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) with mixture of experts outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - aux_loss for the sparse modules. - - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary - loss for Mixture of Experts models. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - aux_loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as - Mixture of Expert's router hidden states terms, to train a MoE model. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary - loss and the z_loss for Mixture of Experts models. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_probs: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class Seq2SeqModelOutput(ModelOutput): - """ - Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential - decoding. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqMoEModelOutput(ModelOutput): - """ - Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential - decoding. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse - modules. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class CausalLMOutput(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class CausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class CausalLMOutputWithCrossAttentions(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Cross attentions weights after the attention softmax, used to compute the weighted average in the - cross-attention heads. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key, - value states of the self-attention and the cross-attention layers if model is used in encoder-decoder - setting. Only relevant if `config.is_decoder = True`. - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class SequenceClassifierOutputWithPast(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class MaskedLMOutput(ModelOutput): - """ - Base class for masked language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqLMOutput(ModelOutput): - """ - Base class for sequence-to-sequence language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqMoEOutput(ModelOutput): - """ - Base class for sequence-to-sequence language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts - models. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - encoder_z_loss: torch.FloatTensor = None - decoder_z_loss: torch.FloatTensor = None - encoder_aux_loss: torch.FloatTensor = None - decoder_aux_loss: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class NextSentencePredictorOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided): - Next sequence prediction (classification) loss. - logits (`torch.FloatTensor` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class SequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqSequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence sentence classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class MultipleChoiceModelOutput(ModelOutput): - """ - Base class for outputs of multiple choice models. - - Args: - loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): - Classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): - *num_choices* is the second dimension of the input tensors. (see *input_ids* above). - - Classification scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class TokenClassifierOutput(ModelOutput): - """ - Base class for outputs of token classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : - Classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): - Classification scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class QuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of question answering models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - start_logits: torch.FloatTensor = None - end_logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence question answering models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - start_logits: torch.FloatTensor = None - end_logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class SemanticSegmenterOutput(ModelOutput): - """ - Base class for outputs of semantic segmentation models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): - Classification scores for each pixel. - - - - The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is - to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the - original image size as post-processing. You should always check your logits shape and resize as needed. - - - - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class ImageClassifierOutput(ModelOutput): - """ - Base class for outputs of image classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states - (also called feature maps) of the model at the output of each stage. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class ImageClassifierOutputWithNoAttention(ModelOutput): - """ - Base class for outputs of image classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also - called feature maps) of the model at the output of each stage. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class DepthEstimatorOutput(ModelOutput): - """ - Base class for outputs of depth estimation models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): - Predicted depth for each pixel. - - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - predicted_depth: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class ImageSuperResolutionOutput(ModelOutput): - """ - Base class for outputs of image super resolution models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Reconstruction loss. - reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Reconstructed images, possibly upscaled. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states - (also called feature maps) of the model at the output of each stage. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - reconstruction: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Wav2Vec2BaseModelOutput(ModelOutput): - """ - Base class for models that have been trained with the Wav2Vec2 loss objective. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`): - Sequence of extracted feature vectors of the last convolutional layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - extract_features: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class XVectorOutput(ModelOutput): - """ - Output type of [`Wav2Vec2ForXVector`]. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): - Classification hidden states before AMSoftmax. - embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): - Utterance embeddings used for vector similarity-based retrieval. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - embeddings: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BackboneOutput(ModelOutput): - """ - Base class for outputs of backbones. - - Args: - feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): - Feature maps of the stages. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`, - depending on the backbone. - - Hidden-states of the model at the output of each stage plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Only applicable if the backbone uses attention. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - feature_maps: Tuple[torch.FloatTensor] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPoolingAndProjection(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`. - - Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - projection_state: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class Seq2SeqSpectrogramOutput(ModelOutput): - """ - Base class for sequence-to-sequence spectrogram outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Spectrogram generation loss. - spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): - The predicted spectrogram. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - spectrogram: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqTSModelOutput(ModelOutput): - """ - Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up - sequential decoding. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Shift values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to shift back to the original magnitude. - scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Scaling values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to rescale back to the original magnitude. - static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): - Static features of each time series' in a batch which are copied to the covariates at inference time. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - loc: Optional[torch.FloatTensor] = None - scale: Optional[torch.FloatTensor] = None - static_features: Optional[torch.FloatTensor] = None - - -@dataclass -class Seq2SeqTSPredictionOutput(ModelOutput): - """ - Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the - chosen distribution. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided): - Distributional loss. - params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`): - Parameters of the chosen distribution. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Shift values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to shift back to the original magnitude. - scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Scaling values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to rescale back to the original magnitude. - static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): - Static features of each time series' in a batch which are copied to the covariates at inference time. - """ - - loss: Optional[torch.FloatTensor] = None - params: Optional[Tuple[torch.FloatTensor]] = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - loc: Optional[torch.FloatTensor] = None - scale: Optional[torch.FloatTensor] = None - static_features: Optional[torch.FloatTensor] = None - - -@dataclass -class SampleTSPredictionOutput(ModelOutput): - """ - Base class for time series model's predictions outputs that contains the sampled values from the chosen - distribution. - - Args: - sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`): - Sampled values from the chosen distribution. - """ - - sequences: torch.FloatTensor = None - - -@dataclass -class MaskedImageModelingOutput(ModelOutput): - """ - Base class for outputs of masked image completion / in-painting models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): - Reconstruction loss. - reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Reconstructed / completed images. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or - when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states - (also called feature maps) of the model at the output of each stage. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when - `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - reconstruction: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - @property - def logits(self): - warnings.warn( - "logits attribute is deprecated and will be removed in version 5 of Transformers." - " Please use the reconstruction attribute to retrieve the final output instead.", - FutureWarning, - ) - return self.reconstruction diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py deleted file mode 100644 index a1b413b0e0..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from torch import nn - -ALL_LAYERNORM_LAYERS = [nn.LayerNorm] \ No newline at end of file diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py deleted file mode 100644 index 3dac4a51c6..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py +++ /dev/null @@ -1,2535 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -import copy -import importlib.metadata -import json -import os -from collections.abc import Iterable -from dataclasses import dataclass -from typing import Any - -import torch -from packaging import version -from transformers.configuration_utils import PretrainedConfig -from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6 -from transformers.utils import ( - is_hqq_available, - is_optimum_quanto_available, - is_torch_greater_or_equal, - logging, -) - -if is_hqq_available(): - from hqq.core.quantize import Quantizer as HQQQuantizer - -logger = logging.get_logger(__name__) - - -class Cache: - """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. - """ - - is_compileable = False - - def __init__(self): - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length (i.e. max capacity) of the cache object""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: int | None = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) - if self.value_cache[layer_idx].numel(): - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) - - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None - - -@dataclass -class CacheConfig: - """ - Base class for cache configs - """ - - cache_implementation: None - - @classmethod - def from_dict(cls, config_dict, **kwargs): - """ - Constructs a CacheConfig instance from a dictionary of parameters. - Args: - config_dict (Dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. - """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: str | os.PathLike): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict - def to_dict(self) -> dict[str, Any]: - """ - Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - return copy.deepcopy(self.__dict__) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. - Returns: - str: JSON formatted string representing the configuration instance. - """ - return json.dumps(self.__dict__, indent=2) + "\n" - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update - def update(self, **kwargs): - """ - Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, - returning all the unused kwargs. - - Args: - kwargs (`Dict[str, Any]`): - Dictionary of attributes to tentatively update this class. - - Returns: - `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. - """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) - - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs - - -@dataclass -class QuantizedCacheConfig(CacheConfig): - """ - Configuration class for quantized cache settings. - - Attributes: - backend (`str`, *optional*, defaults to `"quanto"`): - Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] - nbits (`Optional[int]`, *optional*, defaults to 4): - Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. - axis_key (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - axis_value (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - q_group_size (`Optional[int]`, *optional*, defaults to 64): - Size of the quantization group, should be a divisor of the model's hidden dimension. - Defaults to 64. - residual_length (`Optional[int]`, *optional*, defaults to 128): - Length of the residual cache which will always be stored in original precision. - Defaults to 128. - compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. - device (`str`, *optional*, defaults to `"cpu"`): - Device on which to perform computations, should be same as the model's device. - """ - - def __init__( - self, - backend: str = "quanto", - nbits: int | None = 4, - axis_key: int | None = 0, - axis_value: int | None = 0, - q_group_size: int | None = 64, - residual_length: int | None = 128, - compute_dtype: torch.dtype | None = torch.float16, - device: str | None = "cpu", - ): - self.backend = backend - self.nbits = nbits - self.axis_key = axis_key - self.axis_value = axis_value - self.q_group_size = q_group_size - self.residual_length = residual_length - self.compute_dtype = compute_dtype - self.device = device - - def validate(self): - """Validates if the arguments passed are correct""" - - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " - "but found {found_value}" - ) - # Check that the values are reasonable in general (nbits, axis) - # Later in QuantizedCache init we check if they are supported for that particular backend - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - incorrect_arg_msg.format( - key="nbits", - correct_value="2 or 4 or 8", - found_value=self.nbits, - ), - ) - if self.q_group_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="q_group_size", - correct_value="a positive integer", - found_value=self.q_group_size, - ), - ) - if self.residual_length < 0: - raise ValueError( - incorrect_arg_msg.format( - key="residual_length", - correct_value="a positive integer", - found_value=self.residual_length, - ), - ) - - if self.axis_key not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_key", - correct_value="`1` or `0`, `-1`", - found_value=self.axis_key, - ), - ) - - if self.axis_value not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_value", - correct_value="`1` or `0` or `-1`", - found_value=self.axis_value, - ), - ) - - -@dataclass -class StaticCacheConfig(CacheConfig): - """ - Configuration class for static cache settings. - """ - - cache_implementation = "static" - - def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): - self.batch_size = batch_size - self.max_cache_len = max_cache_len - self.device = device - - def validate(self): - """Validates if the arguments passed are correct""" - - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " - "but found {found_value}" - ) - - if self.batch_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="batch_size", - correct_value="> 0", - found_value=self.batch_size, - ), - ) - - if self.max_cache_len <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="max_cache_len", - correct_value="> 0", - found_value=self.max_cache_len, - ), - ) - - -class DynamicCache(Cache): - """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = DynamicCache() - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - DynamicCache() - ``` - """ - - def __init__(self, _distributed_cache_data: Iterable = None) -> None: - super().__init__() - self._seen_tokens = ( - 0 # Used in `generate` to keep tally of how many tokens the cache has seen - ) - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. - if _distributed_cache_data is not None: - for key_states, value_states in _distributed_cache_data: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if key_states is not None: - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[ - layer_idx - ].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx], key_states], dim=-2 - ) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) - <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" - return None - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def crop(self, max_length: int): - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - # In case it is negative - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - self._seen_tokens = max_length - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCache"]: - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - out = [] - for i in range(0, full_batch_size, split_size): - current_split = DynamicCache() - current_split._seen_tokens = self._seen_tokens - current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] - current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] - out.append(current_split) - return out - - @classmethod - def from_batch_splits(cls, splits: list["DynamicCache"]) -> "DynamicCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - cache = cls() - for idx in range(len(splits[0])): - key_cache = [ - current.key_cache[idx] for current in splits if current.key_cache[idx].numel() - ] - value_cache = [ - current.value_cache[idx] for current in splits if current.value_cache[idx].numel() - ] - if key_cache != []: - layer_keys = torch.cat(key_cache, dim=0) - layer_values = torch.cat(value_cache, dim=0) - cache.update(layer_keys, layer_values, idx) - return cache - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( - repeats, dim=0 - ) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] - - -# Utilities for `DynamicCache` <> torch.export support -def _flatten_dynamic_cache( - dynamic_cache: DynamicCache, -): - """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" - if not isinstance(dynamic_cache, DynamicCache): - raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") - - if not is_torch_greater_or_equal_than_2_6: - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." - ) - - # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten(dictionary) - - -def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - -def _unflatten_dynamic_cache( - values, - context: torch.utils._pytree.Context, -): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - for k, v in dictionary.items(): - setattr(cache, k, v) - return cache - - -def _flatten_dynamic_cache_for_fx(cache, spec): - dictionary = { - "key_cache": getattr(cache, "key_cache"), - "value_cache": getattr(cache, "value_cache"), - } - return torch.utils._pytree.tree_flatten(dictionary)[0] - - -if is_torch_greater_or_equal("2.3"): - torch.utils._pytree.register_pytree_node( - DynamicCache, - _flatten_dynamic_cache, - _unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, - ) - # TODO (tmanlaibaatar) This won't be needed in torch 2.7. - torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) - - -class OffloadedCache(DynamicCache): - """ - A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. - Useful for generating from models with very long context. - - In addition to the default accelerator stream, where all forward() computations happen, - this class uses another stream, the prefetch stream, which it creates itself. - Since scheduling of operations on separate streams happens independently, this class uses - the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. - The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to - ensure the eviction is scheduled after all computations on that cache are finished. - """ - - def __init__(self) -> None: - if not ( - torch.cuda.is_available() - or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) - ): - raise RuntimeError( - "OffloadedCache can only be used with a GPU" - + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") - ) - - super().__init__() - self.original_device = [] - self.prefetch_stream = None - self.prefetch_stream = ( - torch.Stream() - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.Stream() - ) - self.beam_idx = None # used to delay beam search operations - - def prefetch_layer(self, layer_idx: int): - "Starts prefetching the next layer cache" - if layer_idx < len(self): - with ( - self.prefetch_stream - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.stream(self.prefetch_stream) - ): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to( - device, non_blocking=True - ) - - def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - if len(self) > 2: - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(self) - self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to( - "cpu", non_blocking=True - ) - self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to( - "cpu", non_blocking=True - ) - - def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: - "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." - if layer_idx < len(self): - # Evict the previous layer if necessary - if is_torch_greater_or_equal("2.7", accept_dev=True): - torch.accelerator.current_stream().synchronize() - else: - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) - # Load current layer cache to its original device if not already there - original_device = self.original_device[layer_idx] - self.prefetch_stream.synchronize() - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - # Now deal with beam search ops which were delayed - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(original_device) - key_tensor = key_tensor.index_select(0, self.beam_idx) - value_tensor = value_tensor.index_select(0, self.beam_idx) - # Prefetch the next layer - self.prefetch_layer((layer_idx + 1) % len(self)) - return (key_tensor, value_tensor) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Saves the beam indices and reorders the cache when the tensor is back to its device.""" - # We delay this operation until the tensors are back to their original - # device because performing torch.index_select on the CPU is very slow - del self.beam_idx - self.beam_idx = beam_idx.clone() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if len(self.key_cache) < layer_idx: - raise ValueError( - "OffloadedCache does not support model usage where layers are skipped. Use DynamicCache." - ) - elif len(self.key_cache) == layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.original_device.append(key_states.device) - self.evict_previous_layer(layer_idx) - else: - key_tensor, value_tensor = self[layer_idx] - self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError - # if a method is not supposed to be supported in a subclass we should set it to None - from_legacy_cache = None - - to_legacy_cache = None - - -class QuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - """ - - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - super().__init__() - self._quantized_key_cache: list[torch.Tensor] = [] - self._quantized_value_cache: list[torch.Tensor] = [] - - self.nbits = cache_config.nbits - self.residual_length = cache_config.residual_length - self.q_group_size = cache_config.q_group_size - self.axis_key = cache_config.axis_key - self.axis_value = cache_config.axis_value - self.compute_dtype = cache_config.compute_dtype - self.device = cache_config.device - - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - if len(self.key_cache) < layer_idx: - raise ValueError( - "QuantizedCache does not support model usage where layers are skipped. Use DynamicCache." - ) - elif len(self.key_cache) == layer_idx: - self._quantized_key_cache.append( - self._quantize(key_states.contiguous(), axis=self.axis_key) - ) - self._quantized_value_cache.append( - self._quantize(value_states.contiguous(), axis=self.axis_value) - ) - self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - self.value_cache.append( - torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - ) - keys_to_return, values_to_return = key_states, value_states - else: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] - values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] - - keys_to_return = torch.cat(keys_to_return, dim=-2) - values_to_return = torch.cat(values_to_return, dim=-2) - if ( - self.key_cache[layer_idx].dim() == 4 - and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length - ): - self._quantized_key_cache[layer_idx] = self._quantize( - keys_to_return.contiguous(), axis=self.axis_key - ) - self._quantized_value_cache[layer_idx] = self._quantize( - values_to_return.contiguous(), axis=self.axis_value - ) - self.key_cache[layer_idx] = torch.zeros( - 0, dtype=key_states.dtype, device=key_states.device - ) - self.value_cache[layer_idx] = torch.zeros( - 0, dtype=key_states.dtype, device=key_states.device - ) - else: - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx], key_states], dim=-2 - ) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is - # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def _quantize(self, tensor, axis): - """Quantizes a key/value using a defined quantization method.""" - raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") - - def _dequantize(self, q_tensor): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") - - -class QuantoQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - - Example: - - ```python - >>> # Run pip install quanto first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4) - >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - QuantoQuantizedCache() - ``` - """ - - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - - if is_optimum_quanto_available(): - optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) - if optimum_quanto_version <= version.parse("0.2.5"): - raise ImportError( - f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." - ) - from optimum.quanto import MaxOptimizer, qint2, qint4 - - if self.nbits not in [2, 4]: - raise ValueError( - f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}" - ) - - if self.axis_key not in [0, -1]: - raise ValueError( - f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}" - ) - - if self.axis_value not in [0, -1]: - raise ValueError( - f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" - ) - - self.qtype = qint4 if self.nbits == 4 else qint2 - self.optimizer = ( - MaxOptimizer() - ) # hardcode as it's the only one for per-channel quantization - - def _quantize(self, tensor, axis): - # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore - if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight - - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) - return qtensor - - def _dequantize(self, qtensor): - return qtensor.dequantize() - - -class HQQQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - - Example: - - ```python - >>> # Run pip install hqq first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) - >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HQQQuantizedCache() - ``` - """ - - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" - ) - - if self.axis_key not in [0, 1]: - raise ValueError( - f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}" - ) - - if self.axis_value not in [0, 1]: - raise ValueError( - f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}" - ) - - self.quantizer = HQQQuantizer - - def _quantize(self, tensor, axis): - qtensor, meta = self.quantizer.quantize( - tensor, - axis=axis, - device=self.device, - compute_dtype=self.compute_dtype, - nbits=self.nbits, - group_size=self.q_group_size, - ) - meta["compute_dtype"] = self.compute_dtype - self.quantizer.cuda( - qtensor, meta=meta, device=self.device - ) # Move to device and cast to dtype - return qtensor, meta - - def _dequantize(self, qtensor): - quant_tensor, meta = qtensor - tensor = self.quantizer.dequantize(quant_tensor, meta) - return tensor - - -class SinkCache(Cache): - """ - A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to - generate beyond the length of its context window, without losing fluency in the conversation. As it discards past - tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Parameters: - window_length (`int`): - The length of the context window. - num_sink_tokens (`int`): - The number of sink tokens. See the original paper for more information. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SinkCache() - ``` - """ - - is_sliding = True - - def __init__(self, window_length: int, num_sink_tokens: int) -> None: - super().__init__() - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - self.window_length = window_length - self.num_sink_tokens = num_sink_tokens - self.cos_sin_rerotation_cache = {} - self._cos_cache = None - self._sin_cache = None - self._seen_tokens = ( - 0 # Used in `generate` to keep tally of how many tokens the cache has seen - ) - - @staticmethod - def _rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_key_rotary_pos_emb( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> torch.Tensor: - rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) - return rotated_key_states - - def _get_rerotation_cos_sin( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - if key_states.shape[-2] not in self.cos_sin_rerotation_cache: - # Upcast to float32 temporarily for better accuracy - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence - original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] - shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] - original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] - shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - - self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( - rerotation_cos.to(key_states.dtype).unsqueeze(0), - rerotation_sin.to(key_states.dtype).unsqueeze(0), - ) - return self.cos_sin_rerotation_cache[key_states.shape[-2]] - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length.""" - return self.window_length - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, - `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the - rotation as the tokens are shifted. - - Return: - A tuple containing the updated key and value states. - """ - # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models - # with partially rotated position embeddings, like Phi or Persimmon. - if cache_kwargs is None: - cache_kwargs = {} - sin = cache_kwargs.get("sin") - cos = cache_kwargs.get("cos") - partial_rotation_size = cache_kwargs.get("partial_rotation_size") - using_rope = cos is not None and sin is not None - - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the sin/cos cache, which holds sin/cos values for all possible positions - if using_rope and layer_idx == 0: - # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove - # after all RoPE models have a llama-like cache utilization. - if cos.dim() == 2: - self._cos_cache = cos - self._sin_cache = sin - elif self._cos_cache is None: - self._cos_cache = cos[0, ...] - self._sin_cache = sin[0, ...] - elif self._cos_cache.shape[0] < self.window_length: - self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) - self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) - - # [bsz, num_heads, seq_len, head_dim] - if len(self.key_cache) <= layer_idx: - # Empty cache - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: - # Growing cache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - else: - # Shifting cache - keys_to_keep = self.key_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : - ] - - # On RoPE models, we need to recompute the Key rotation as the tokens are shifted - if using_rope: - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( - key_states, - self._cos_cache[: self.window_length], - self._sin_cache[: self.window_length], - ) - if partial_rotation_size is not None: - keys_to_keep, keys_pass = ( - keys_to_keep[..., :partial_rotation_size], - keys_to_keep[..., partial_rotation_size:], - ) - keys_to_keep = self._apply_key_rotary_pos_emb( - keys_to_keep, rerotation_cos, rerotation_sin - ) - if partial_rotation_size is not None: - keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) - - # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] - self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) - - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] - values_to_keep = self.value_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : - ] - self.value_cache[layer_idx] = torch.cat( - [sink_values, values_to_keep, value_states], dim=-2 - ) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - -class StaticCache(Cache): - """ - Static Cache class to be used with `torch.compile(model)` and `torch.export()`. - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. If you are manually setting the batch size, make sure to take into account the - number of beams if you are running beam search - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - - >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - StaticCache() - ``` - """ - - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.float32, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super().__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = ( - config.max_position_embeddings if max_cache_len is None else max_cache_len - ) - - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - # Note: There will be significant perf decrease if switching to use 5D tensors instead. - cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - self.max_cache_len, - self.head_dim, - ) - device = torch.device(device) if device is not None else None - for idx in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[idx] - else: - layer_device = device - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class SlidingWindowCache(StaticCache): - """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - - indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) - - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache - - >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - - >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SlidingWindowCache() - ``` - """ - - is_sliding = True - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.float32, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - max_cache_len = min(config.sliding_window, max_cache_len) - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) - if cache_position.shape[0] > self.max_cache_len: - k_out = key_states[:, :, -self.max_cache_len :, :] - v_out = value_states[:, :, -self.max_cache_len :, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones( - self.max_cache_len, dtype=torch.long, device=value_states.device - ).cumsum(0) - cache_position = cache_position.clamp(0, self.max_cache_len - 1) - to_shift = cache_position >= self.max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len - - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - - return k_out, v_out - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class EncoderDecoderCache(Cache): - """ - Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and - cross-attention caches. - - Example: - - ```python - >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") - >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") - - >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") - - >>> # Prepare cache classes for encoder and decoder and pass it to model's forward - >>> self_attention_cache = DynamicCache() - >>> cross_attention_cache = DynamicCache() - >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - EncoderDecoderCache() - ``` - - """ - - def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): - super().__init__() - self.self_attention_cache = self_attention_cache - self.cross_attention_cache = cross_attention_cache - self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) - - self.is_updated = {} - for layer_idx in range(len(cross_attention_cache.key_cache)): - self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) - - def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], - ) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.self_attention_cache) - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" - legacy_cache = () - if len(self.cross_attention_cache) > 0: - for self_attn, cross_attn in zip( - self.self_attention_cache.to_legacy_cache(), - self.cross_attention_cache.to_legacy_cache(), - ): - legacy_cache += (self_attn + cross_attn,) - else: - legacy_cache = self.self_attention_cache.to_legacy_cache() - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None - ) -> "EncoderDecoderCache": - """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" - cache = cls( - self_attention_cache=DynamicCache(), - cross_attention_cache=DynamicCache(), - ) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx][:2] - cache.self_attention_cache.update(key_states, value_states, layer_idx) - if len(past_key_values[layer_idx]) > 2: - key_states, value_states = past_key_values[layer_idx][2:] - cache.cross_attention_cache.update(key_states, value_states, layer_idx) - cache.is_updated[layer_idx] = True - return cache - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - return self.self_attention_cache.get_seq_length(layer_idx) - - def reset(self): - if hasattr(self.self_attention_cache, "reset"): - self.self_attention_cache.reset() - if hasattr(self.cross_attention_cache, "reset"): - self.cross_attention_cache.reset() - elif not hasattr(self.self_attention_cache, "reset") and not hasattr( - self.cross_attention_cache, "reset" - ): - raise ValueError( - "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " - "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " - f"Got {self.self_attention_cache.__str__()} for the self attention cache and " - f"{self.cross_attention_cache.__str__()} for the cross attention cache." - ) - for layer_idx in self.is_updated: - self.is_updated[layer_idx] = False - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - self.self_attention_cache.reorder_cache(beam_idx) - self.cross_attention_cache.reorder_cache(beam_idx) - - def check_dynamic_cache(self, method: str): - if not ( - isinstance(self.self_attention_cache, DynamicCache) - and isinstance(self.cross_attention_cache, DynamicCache) - ): - raise ValueError( - f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " - f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." - ) - - # TODO(gante, sanchit-gandhi): move following functionality into `.generate` - def crop(self, maximum_length: int): - """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" - self.check_dynamic_cache(self.crop.__name__) - self.self_attention_cache.crop(maximum_length) - - def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - self.check_dynamic_cache(self.batch_split.__name__) - self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) - cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) - - out = [] - for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): - out.append(EncoderDecoderCache(self_attn, cross_attn)) - return out - - @classmethod - def from_batch_splits(cls, splits: list["EncoderDecoderCache"]) -> "EncoderDecoderCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - for idx in range(len(splits[0])): - layer_keys = torch.cat( - [current.self_attention_cache.key_cache[idx] for current in splits], dim=0 - ) - layer_values = torch.cat( - [current.self_attention_cache.value_cache[idx] for current in splits], dim=0 - ) - self_attention_cache.update(layer_keys, layer_values, idx) - - layer_keys = torch.cat( - [current.cross_attention_cache.key_cache[idx] for current in splits], dim=0 - ) - layer_values = torch.cat( - [current.cross_attention_cache.value_cache[idx] for current in splits], dim=0 - ) - cross_attention_cache.update(layer_keys, layer_values, idx) - return cls(self_attention_cache, cross_attention_cache) - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_repeat_interleave.__name__) - self.self_attention_cache.batch_repeat_interleave(repeats) - self.cross_attention_cache.batch_repeat_interleave(repeats) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_select_indices.__name__) - self.self_attention_cache.batch_select_indices(indices) - self.cross_attention_cache.batch_select_indices(indices) - - -class HybridCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention - and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention - and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. - - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ - - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert - # ALL changes from the PR that commented the line below when reactivating it. - # is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.float32, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - self.max_cache_len = max_cache_len - self.max_batch_size = max_batch_size - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if config.num_key_value_heads is None - else config.num_key_value_heads - ) - - layer_switch = ( - config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 - ) # 2 is for BC - self.is_sliding = torch.tensor( - [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], - dtype=torch.bool, - ) - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - global_cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - max_cache_len, - self.head_dim, - ) - sliding_cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - min(config.sliding_window, max_cache_len), - self.head_dim, - ) - device = torch.device(device) if device is not None and isinstance(device, str) else None - for i in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[i] - else: - layer_device = device - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def _sliding_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - if cache_position.shape[0] > max_cache_len: - k_out = key_states[:, :, -max_cache_len:, :] - v_out = value_states[:, :, -max_cache_len:, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) - cache_position = cache_position.clamp(0, max_cache_len - 1) - to_shift = cache_position >= max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % max_cache_len - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - return k_out, v_out - - def _static_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - sliding_window = cache_kwargs.get("sliding_window") - - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used - # when the cache is initialized in the forward pass (e.g. Gemma2) - if self.key_cache[layer_idx].device != key_states.device: - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) - if self.value_cache[layer_idx].device != value_states.device: - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if sliding_window: - update_fn = self._sliding_update - else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def get_seq_length(self, layer_idx: int | None = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." - ) - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class HybridChunkedCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention - and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention - and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. - - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ - - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert - # ALL changes from the PR that commented the line below when reactivating it. - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.bfloat16, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) - else: - self.sliding_window = config.sliding_window - self.max_cache_len = max_cache_len - self.max_batch_size = max_batch_size - self.head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - self._dtype = dtype - - if hasattr(config.get_text_config(), "no_rope_layers"): - self.is_sliding = config.no_rope_layers - else: - layer_switch = getattr(config, "sliding_window_pattern", 2) - self.is_sliding = [ - bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers) - ] - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] - - def initialise_cache_layer(self, layer_idx, key_states): - if len(self.key_cache) > layer_idx: - return - - num_key_value_heads = key_states.shape[1] - device = key_states.device - global_cache_shape = ( - self.max_batch_size, - num_key_value_heads, - self.max_cache_len, - self.head_dim, - ) - sliding_cache_shape = ( - self.max_batch_size, - num_key_value_heads, - self.sliding_window, - self.head_dim, - ) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def _sliding_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - cumulative_length = self.cumulative_length[layer_idx] - # Update it now that we saved the value above - self.cumulative_length[layer_idx] += key_states.shape[-2] - is_full = cumulative_length >= max_cache_len - if is_full: - full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) - # Fast decoding path -> here as the effective size is still sliding window, it is extremely important - # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed adress - # in memory (the values are the same as the full states, but not the address!!) - if key_states.shape[-2] == 1: - self.key_cache[layer_idx].copy_(full_key_states) - self.value_cache[layer_idx].copy_(full_value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: - # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) - if cumulative_length == 0: - full_key_states = key_states - full_value_states = value_states - else: - full_key_states = torch.cat( - (k_out[:, :, :cumulative_length, :], key_states), dim=-2 - ) - full_value_states = torch.cat( - (v_out[:, :, :cumulative_length, :], value_states), dim=-2 - ) - else: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) - self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return full_key_states, full_value_states - - def _static_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - self.initialise_cache_layer(layer_idx, key_states) - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if self.is_sliding[layer_idx]: - update_fn = self._sliding_update - else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def get_seq_length(self, layer_idx: int | None = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." - ) - if len(self.key_cache) == 0: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] - - -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values - MambaCache() - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: torch.device | str | None = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. Mamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to( - device=conv_state.device, dtype=conv_state.dtype - ) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - -class OffloadedStaticCache(StaticCache): - """ - Static cache class to be used with `torch.compile(model)` that offloads to the CPU or - another device. - - Args: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize - the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`Union[str, torch.device]`): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*): - The default `dtype` to use when initializing the cache. - offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): - The device to offload to. Defaults to CPU. - layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is splitted between differents gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None, - device: str | torch.device, - dtype: torch.dtype | None = None, - offload_device: str | torch.device = torch.device("cpu"), - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super(Cache, self).__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = ( - config.max_position_embeddings if max_cache_len is None else max_cache_len - ) - self.device = ( - torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) - ) - self.offload_device = torch.device(offload_device) - self._dtype = dtype if dtype is not None else torch.float32 - - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - - cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) - - # Create offloaded CPU tensors. - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - for i in range(config.num_hidden_layers): - # First layer is always on-device. - device = self.device if i == 0 else self.offload_device - - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) - - self.key_cache.append(key_cache) - self.value_cache.append(value_cache) - - # Create device tensors. - self._device_key_cache: list[torch.Tensor] = [] - self._device_value_cache: list[torch.Tensor] = [] - - for i in range(2): - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) - - self._device_key_cache.append(key_cache) - self._device_value_cache.append(value_cache) - - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 - - # Create new CUDA stream for parallel prefetching. - self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, *optional*): - Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the - `cache_position` input to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - - if layer_idx == 0: - # Update seen tokens. - # TODO(gante): Remove this. - self._seen_tokens += key_states.shape[-2] - - # Always there. - k_out = self.key_cache[0] - v_out = self.value_cache[0] - else: - # Wait for prefetch stream. - if self._prefetch_stream is not None: - torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) - - k_out = self._device_key_cache[layer_idx & 1] - v_out = self._device_value_cache[layer_idx & 1] - - self._prefetch_layer(layer_idx + 1) - - cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - - # Copy the values to the offloaded device as well. - if layer_idx == 0: - self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) - self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does - # explicitly an in-place operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # Copy the values to the offloaded device as well. - if layer_idx != 0: - cache_position = cache_position.to(self.offload_device) - key_states = key_states.to(self.offload_device) - value_states = value_states.to(self.offload_device) - - try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states - - return k_out, v_out - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - - # TODO(gante): Remove this. - return self._seen_tokens - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length of the cached states.""" - - return self.max_cache_len - - def reset(self) -> None: - """Resets the cache values while preserving the objects.""" - - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 - - # Zero out cache. - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address. - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - @property - def seen_tokens(self) -> int: - # For backwards compatibility. - # TODO(gante): Remove this. - return self._seen_tokens - - def _create_key_value_cache_tensors( - self, shape: tuple[int, ...], device: torch.device - ) -> tuple[torch.Tensor, torch.Tensor]: - """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static - addresses for non-CPU tensors. - - Args: - shape (`Tuple[int, ...]`): Shape. - device (`torch.device`): Device. - - Returns: - Key and value cache tensors as a tuple. - """ - - is_cpu_device = device == torch.device("cpu") - - key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(key_cache) - torch._dynamo.mark_static_address(value_cache) - - return key_cache, value_cache - - def _prefetch_layer(self, layer_idx: int) -> None: - """Prefetch a layer to the device. Needs to be called in order of layer indices.""" - - # Don't fetch layers that do not exist. - if layer_idx >= len(self.key_cache): - return - - # Alternate between two on-device caches. - if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(layer_idx) - else: - self._prefetch_layer_in_context(layer_idx) - - def _prefetch_layer_in_context(self, layer_idx: int) -> None: - """Performs the actual copy of the layer to device cache.""" - - self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) - self._device_value_cache[layer_idx & 1].copy_( - self.value_cache[layer_idx], non_blocking=True - ) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py deleted file mode 100644 index b17883628f..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py +++ /dev/null @@ -1,289 +0,0 @@ -# coding=utf-8 -# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -import math -from typing import Callable, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint - -from transformers.cache_utils import Cache -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.processing_utils import Unpack -from transformers.utils import ( - is_torch_flex_attn_available, - logging, -) -from .transformers_4_51_3__configuration_llama4 import Llama4TextConfig - - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from transformers.integrations.flex_attention import make_flex_block_causal_mask - -logger = logging.get_logger(__name__) - - -class Llama4TextL2Norm(torch.nn.Module): - def __init__(self, eps: float = 1e-6): - super().__init__() - self.eps = eps - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - return self._norm(x.float()).type_as(x) - - def extra_repr(self): - return f"eps={self.eps}" - - -class Llama4TextRotaryEmbedding(nn.Module): - def __init__(self, config: Llama4TextConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - self.rope_type = "llama3" if config.rope_scaling is not None else "default" - - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - freqs_cis = freqs_cis * self.attention_scaling - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - # print(f"{module.layer_idx=} {module.num_key_value_groups=}") - # print(f"{module.layer_idx=} {module.head_dim=}") - # print(f"{module.layer_idx=} {module.training=}") - # print(f"{scaling=}") - # print(f"{dropout=}") - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class Llama4TextAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Llama4TextConfig, layer_idx, use_rope: bool): # we added use_rope to not be dependent on the layer index - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_attention_heads = config.num_attention_heads - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.num_key_value_heads = config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attn_scale = config.attn_scale - self.floor_scale = config.floor_scale - self.attn_temperature_tuning = config.attn_temperature_tuning - self.attention_dropout = config.attention_dropout - self.is_causal = True - # self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers - self.use_rope = use_rope - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - if self.config.use_qk_norm and self.use_rope: - self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape) - key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - if self.use_rope: # the 16E model skips rope for long context on certain layers - query_states, key_states = apply_rotary_emb( - query_states, key_states, position_embeddings.to(query_states.device) - ) - - if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm - query_states = self.qk_norm(query_states) - key_states = self.qk_norm(key_states) - - # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers - if self.attn_temperature_tuning and not self.use_rope: - device = query_states.device - attn_scales = ( - torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 - ).to(device) - attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 - query_states = (query_states * attn_scales).to(query_states.dtype) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - # print(f"{self.layer_idx=} {cache_position=} {attention_mask=}") - # print(f"{self.layer_idx=} {query_states.flatten()[:10]=}") - # print(f"{self.layer_idx=} {key_states.flatten()[:10]=}") - # print(f"{self.layer_idx=} {value_states.flatten()[:10]=}") - # print(f"{self.layer_idx=} {kwargs=}") - # print(f"{self.layer_idx=} {attention_interface=}") - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py deleted file mode 100644 index 9acc27eb9f..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py +++ /dev/null @@ -1,213 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# mypy: ignore-errors -from copy import deepcopy -from typing import Any - -import torch -from transformers.cache_utils import ( - Cache, # used to let GenerationMixin know that we use a Cache object -) - -from .configuration_decilm import DeciLMConfig -from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2 -from .transformers_4_44_2__cache_utils import SinkCache, SlidingWindowCache, StaticCache -from .transformers_4_51_3__cache_utils import HybridChunkedCache - -LayerIndex = tuple[ - int, ... -] # supports both regular transformer blocks and parallel transformer multi-blocks - - -class VariableCache(Cache_4_44_2, Cache): - """ - A Cache object that supports a different Cache implementation for every layer, - including layers without any kv-cache. - Implemented using a list of Cache objects, each represents a "model" with 1 layer. - The default implementation for the layer caches is StaticCache. - The cache of each layer is allocated to the same gpu as the layer itself. - """ - - def __init__( - self, - *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions - config: DeciLMConfig, - batch_size: int | None = None, - max_cache_len: int | None = None, - dtype: torch.dtype = torch.get_default_dtype(), - max_batch_size: int | None = None, - **kwargs, - ) -> None: - Cache_4_44_2.__init__(self) - - self.config = deepcopy(config) - self.max_batch_size = batch_size or max_batch_size - self.batch_size = self.max_batch_size - self.max_cache_len = ( - config.max_position_embeddings if (max_cache_len is None) else max_cache_len - ) - self.dtype = dtype - - self.layer_caches: dict[LayerIndex, Cache_4_44_2] = {} - self.layer_devices: dict[LayerIndex, torch.device] = {} - - def __repr__(self): - return ( - f"VariableCache:\n" - f"==============\n" - f"max_batch_size={self.max_batch_size}\n" - f"batch_size={self.batch_size}\n" - f"max_cache_len={self.max_cache_len}\n" - f"dtype={self.dtype}\n" - f"layer_caches={self.layer_caches}\n" - f"layer_devices={self.layer_devices}\n" - f"==============\n" - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int | LayerIndex, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if isinstance(layer_idx, int): - layer_idx = _int_to_layer_index(layer_idx) - - if layer_idx not in self.layer_caches: - self.layer_devices[layer_idx] = key_states.device - self._init_layer_cache(layer_idx) - - layer_cache = self.layer_caches[layer_idx] - assert layer_cache is not None, ( - f"Trying to update the cache of a cache-less layer: {layer_idx=}" - ) - - k_out, v_out = layer_cache.update( - key_states=key_states, value_states=value_states, layer_idx=0, cache_kwargs=cache_kwargs - ) - - input_seq_len = key_states.shape[2] # [batch_size, num_kv_heads, seq_len, hidden_size] - cache_seq_len = self.get_seq_length(layer_idx) - seq_len = max(input_seq_len, cache_seq_len) - - k_out = k_out[:, :, :seq_len, :] - v_out = v_out[:, :, :seq_len, :] - return k_out, v_out - - def _init_layer_cache(self, layer_idx: LayerIndex) -> None: - block_config = self.config.get_block_config(layer_idx) - attention_config = block_config.attention - - if attention_config.no_op or attention_config.replace_with_linear: - return None - - device = self.layer_devices[layer_idx] - assert device is not None, f"Trying to init layer cache for {layer_idx=} without device" - - config = deepcopy(self.config) - config.num_hidden_layers = 1 - config.num_key_value_heads = ( - self.config.num_attention_heads // attention_config.n_heads_in_group - ) - - if attention_config.is_llama4: - attention_chunk_size = attention_config.llama4.attention_chunk_size - is_chunked = attention_chunk_size is not None - config.no_rope_layers = [int(is_chunked)] - config.attention_chunk_size = ( - attention_chunk_size if is_chunked else config.get_min_attention_chunk_size() - ) - self.layer_caches[layer_idx] = HybridChunkedCache( - config=config, - max_batch_size=self.max_batch_size, - max_cache_len=self.max_cache_len, - dtype=self.dtype, - ) - return - - if attention_config.window_length is not None: - if not attention_config.is_sink: - config.sliding_window = attention_config.window_length - self.layer_caches[layer_idx] = SlidingWindowCache( - config=config, - max_batch_size=self.max_batch_size, - max_cache_len=self.max_cache_len, - device=device, - dtype=self.dtype, - ) - return - elif not attention_config.unshifted_sink: - self.layer_caches[layer_idx] = SinkCache( - window_length=attention_config.window_length, - num_sink_tokens=attention_config.num_sink_tokens, - ) - return - - self.layer_caches[layer_idx] = StaticCache( - config=config, - max_batch_size=self.max_batch_size, - max_cache_len=self.max_cache_len, - device=device, - dtype=self.dtype, - ) - - def _get_arbitrary_cache(self) -> Cache_4_44_2: - if len(self.layer_caches) == 0: - raise NoCacheFoundError() - layer_cache = next(iter(self.layer_caches.values())) - return layer_cache - - def get_seq_length(self, layer_idx: int | LayerIndex | None = 0) -> int: - """default 0 to match standard HF implementation""" - if (layer_idx is None) or ( - layer_idx == 0 and _int_to_layer_index(0) not in self.layer_caches - ): - try: - layer_cache = self._get_arbitrary_cache() - return layer_cache.get_seq_length() - except NoCacheFoundError: - return 0 - - if isinstance(layer_idx, int): - layer_idx = _int_to_layer_index(layer_idx) - - layer_cache = self.layer_caches[layer_idx] - return layer_cache.get_seq_length() - - def get_max_length(self) -> int | None: - """Returns the maximum sequence length of the cached states.""" - return self.max_cache_len - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def reset(self): - for layer_idx, layer_cache in self.layer_caches.items(): - if hasattr(layer_cache, "reset"): - layer_cache.reset() - else: - self.layer_caches[layer_idx] = None - self.layer_devices[layer_idx] = None - # self._init_layer_cache(layer_idx) - - -class NoCacheFoundError(Exception): - pass - - -def _int_to_layer_index(layer_idx: int) -> LayerIndex: - return (layer_idx,) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py deleted file mode 100644 index 4c8f86cdbc..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py +++ /dev/null @@ -1,210 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import torch -import torch.nn as nn - - -def _apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -class RotaryEmbedding(nn.Module): - """Original rotary positional embedding.""" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.is_neox_style = is_neox_style - self.dtype = dtype - - cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: int | float) -> torch.Tensor: - """Compute the inverse frequency.""" - # NOTE(woosuk): To exactly match the HF implementation, we need to - # use CPU to compute the cache and then move it to GPU. However, we - # create the cache on GPU for faster initialization. This may cause - # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim) - ) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - - -def _yarn_get_mscale(scale: float = 1) -> float: - if scale <= 1: - return 1.0 - return 0.1 * math.log(scale) + 1.0 - - -# Inverse dim formula to find dim based on number of rotations -def _yarn_find_correction_dim( - num_rotations: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 -) -> float: - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) - - -def _yarn_find_correction_range( - low_rot: int, high_rot: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 -) -> tuple[int, int]: - low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def _yarn_linear_ramp_mask(low: float, high: float, dim: int, dtype: torch.dtype) -> torch.Tensor: - if low == high: - high += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -class YaRNScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with YaRN method. - - Credits to Peng et al. github.com/jquesnelle/yarn - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - *, - extrapolation_factor: float = 1, - attn_factor: float = 1, - beta_fast: int = 32, - beta_slow: int = 1, - ) -> None: - self.scaling_factor = scaling_factor - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - # Get n-d magnitude scaling corrected for interpolation - self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - - def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim - ) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - - low, high = _yarn_find_correction_range( - self.beta_fast, self.beta_slow, self.rotary_dim, self.base, self.max_position_embeddings - ) - # print(f"low: {low}, high: {high}") - # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) - ) * self.extrapolation_factor - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - ) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, dtype=torch.float32) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * self.mscale - sin = freqs.sin() * self.mscale - cache = torch.cat((cos, sin), dim=-1) - return cache diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index 73661edba5..f0d5bb0583 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -13,55 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Replacement library for efficiently loading and managing layer-replaced DeciLM models. -- Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations +Replacement library for loading models with layer replacements (AnyModel / sharded HF checkpoints). """ # mypy: ignore-errors import copy import json -import re import tempfile from pathlib import Path from typing import List, Optional -import torch from immutabledict import immutabledict -from lru import LRU from safetensors import safe_open -from safetensors.torch import load_file as safe_load_file -from torch import nn from transformers import PretrainedConfig, PreTrainedModel -import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.anymodel.converter.converter import Converter from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( - DeciLMDecoderLayer, - DeciLMMultiDecoderLayer, - DeciLMRMSNorm, - LMHead, -) from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( extract_block_configs_and_locations, parse_layer_replacement, - sort_replacements, weights_path_to_checkpoint_dir, ) from modelopt.torch.puzzletron.tools.checkpoint_utils import ( - PTH_SUBBLOCKS_DIR_NAME, SAFETENSORS_SUBBLOCKS_DIR_NAME, - infer_weights_dtype, - init_empty_module, - init_module_with_state_dict, load_model_config, ) from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config -from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( - is_in_safetensors_format, - load_and_shard_model, - load_sharded_state_dict, -) +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model class ReplacementLibrary: @@ -78,17 +56,7 @@ def __init__( immutabledict(model_config_overrides) if (model_config_overrides is not None) else None ) - self._loaded_replacements: dict[str, nn.ModuleList] = LRU( - size=256 - ) # least-recently-used dict: a dict of fixed size that evicts old items - - self._dtype = None - - self.teacher_dir = Path(replacement_library_path).parent / "ckpts" / "teacher" self._model_config = None - self._embedding = None - self._ln_f = None - self._lm_head = None self._arbitrary_checkpoint_dir = None @staticmethod @@ -107,17 +75,6 @@ def _ensure_all_checkpoints_are_split(self) -> None: unsplit_checkpoints.append(checkpoint_dir) assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}" - @property - def dtype(self) -> torch.dtype: - if self._dtype is None: - ln_f = self.get_ln_f() - self._dtype = ln_f.weight.dtype - return self._dtype - - @property - def n_layer(self) -> int: - return self.model_config.get_num_hidden_layers() - @property def model_config(self) -> DeciLMConfig: if self._model_config is None: @@ -137,7 +94,7 @@ def create_model_config(self, layer_replacements: list[dict]): model_config.num_hidden_layers = len(block_configs) return model_config - def _get_arbitrary_block_checkpoint_paths(self): + def _get_arbitrary_non_block_checkpoint_paths(self): checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name] @@ -161,7 +118,7 @@ def prepare_tmp_checkpoint_dir( ): arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) - weight_paths = self._get_arbitrary_block_checkpoint_paths() + weight_paths = self._get_arbitrary_non_block_checkpoint_paths() for layer_replacement in layer_replacements: weight_paths += layer_replacement["weight_paths"] @@ -194,194 +151,11 @@ def load_model( model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir) return model - def load_checkpoint(self, checkpoint_dir: str | Path) -> PreTrainedModel: - checkpoint_dir = Path(checkpoint_dir).resolve() - layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) - model = self.load_model(layer_replacements) - return model - - def _locate_replacements_of_entire_checkpoint(self, checkpoint_dir: str | Path) -> list[dict]: - weight_paths_located = [] - layer_replacements = [] - for layer_replacement in self.replacement_library: - weight_paths = layer_replacement["weight_paths"] - weight_paths = [Path(p).absolute().resolve() for p in weight_paths] - layer_replacement["weight_paths"] = weight_paths - if len(weight_paths) > 0 and all( - p.is_relative_to(checkpoint_dir) for p in weight_paths - ): - layer_replacements.append(layer_replacement) - weight_paths_located.extend(weight_paths) - - all_block_weight_paths = [ - p - for p in list((checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).iterdir()) - if p.name not in ("embeddings.safetensors", "lm_head.safetensors") - ] - missing_paths = set(all_block_weight_paths) - set(weight_paths_located) - assert len(missing_paths) == 0, ( - f"Couldn't locate replacements for the entire checkpoint {checkpoint_dir}, missing weights: {missing_paths}" - ) - - dedupped_layer_replacements = [] - for weights_path in all_block_weight_paths: - replacements_with_path = [ - rep for rep in layer_replacements if weights_path in rep["weight_paths"] - ] - largets_replacement_with_path = max( - replacements_with_path, key=lambda rep: len(rep["weight_paths"]) - ) - if largets_replacement_with_path not in dedupped_layer_replacements: - dedupped_layer_replacements.append(largets_replacement_with_path) - - dedupped_layer_replacements = sort_replacements(dedupped_layer_replacements) - return dedupped_layer_replacements - - def get_block( - self, layer_replacement: dict, block_idx_in_replacement: int - ) -> DeciLMDecoderLayer | DeciLMMultiDecoderLayer: - if str(layer_replacement) not in self._loaded_replacements.keys(): - self._loaded_replacements[str(layer_replacement)] = self._load_layer_replacement( - layer_replacement - ) - module_list = self._loaded_replacements[str(layer_replacement)] - block = module_list[block_idx_in_replacement] - return block - - def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: - state_dict = dict() - for weights_path in layer_replacement["weight_paths"]: - if weights_path.suffix == ".safetensors": - curr_state_dict = safe_load_file(weights_path) - elif weights_path.suffix == ".pth": - curr_state_dict = torch.load(weights_path, weights_only=True) - else: - raise ValueError(f"Unrecognized suffix of {weights_path=}") - for param_name in curr_state_dict.keys(): - assert param_name not in state_dict, ( - f"Duplicate entries for {param_name=} in {layer_replacement=}" - ) - state_dict.update(curr_state_dict) - - if len(state_dict) > 0: - block_indices = [ - int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) - for param_name in state_dict.keys() - ] - assert sorted(set(block_indices)) == list( - range(min(block_indices), max(block_indices) + 1) - ), ( - f"Block indices in loaded weight files must be consecutive, but found {sorted(set(block_indices))} in {layer_replacement=}" - ) - - min_block_idx = min(block_indices) - - state_dict = { - param_name.replace( - f"model.layers.{block_idx}.", f"{block_idx - min_block_idx}." - ): param_weight - for block_idx, (param_name, param_weight) in zip(block_indices, state_dict.items()) - } - - dtype = infer_weights_dtype(state_dict) - model_config = copy.deepcopy(self.model_config) - model_config.block_configs = layer_replacement["child_block_configs"] - model_config.num_hidden_layers = len(layer_replacement["child_block_configs"]) - - module_list = nn.ModuleList( - [ - ( - init_empty_module(DeciLMDecoderLayer, dtype, model_config, layer_idx) - if (block_config.parallel_blocks is None) - else init_empty_module(DeciLMMultiDecoderLayer, dtype, model_config, layer_idx) - ) - for layer_idx, block_config in enumerate(layer_replacement["child_block_configs"]) - ] - ) - - module_list.load_state_dict(state_dict, strict=True) - return module_list - - def _move_inactive_blocks_to_cpu(self, active_blocks: list[nn.Module]) -> None: - for module_list in self._loaded_replacements.values(): - for module in module_list: - if module not in active_blocks: - module.to("cpu") - - def get_embedding(self) -> nn.Embedding: - if self._embedding is None: - state_dict = { - "weight": self._get_arbitrary_non_block_param( - self.model_config.get_embedding_layer_name() + ".weight" - ) - } - self._embedding = init_module_with_state_dict( - state_dict, - nn.Embedding, - num_embeddings=self.model_config.vocab_size, - embedding_dim=self.model_config.hidden_size, - ) - return self._embedding - - def get_ln_f(self) -> DeciLMRMSNorm: - if self._ln_f is None: - state_dict = { - "weight": self._get_arbitrary_non_block_param( - self.model_config.get_final_layer_norm_layer_name() + ".weight" - ) - } - self._ln_f = init_module_with_state_dict( - state_dict, - DeciLMRMSNorm, - hidden_size=self.model_config.hidden_size, - eps=self.model_config.rms_norm_eps, - ) - return self._ln_f - - def get_lm_head(self) -> nn.Linear: - if self._lm_head is None: - state_dict = { - "weight": self._get_arbitrary_non_block_param( - self.model_config.get_lm_head_layer_name() + ".weight" - ) - } - self._lm_head = init_module_with_state_dict( - state_dict, - LMHead, - out_features=self.model_config.vocab_size, - in_features=self.model_config.hidden_size, - bias=False, - ) - return self._lm_head - - def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: - checkpoint_dir = self.get_arbitrary_checkpoint_dir() - if ( - is_in_safetensors_format(checkpoint_dir) - or (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists() - ): - partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) - return partial_state_dict[param_name] - - non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" - assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) - non_block_state_dict = torch.load(non_block_pth_path) - return non_block_state_dict[param_name] - def get_arbitrary_checkpoint_dir(self) -> Path: if self._arbitrary_checkpoint_dir is None: self._arbitrary_checkpoint_dir = self._get_arbitrary_checkpoint_dir() return self._arbitrary_checkpoint_dir - def get_teacher_dir(self) -> Path: - return self.teacher_dir - - def get_teacher_lm_head_path(self) -> Path: - return self.get_teacher_dir() / SAFETENSORS_SUBBLOCKS_DIR_NAME / "lm_head.safetensors" - - def get_teacher_embedding_path(self) -> Path: - return self.get_teacher_dir() / SAFETENSORS_SUBBLOCKS_DIR_NAME / "embeddings.safetensors" - def _get_arbitrary_checkpoint_dir(self) -> Path: for layer_replacement in self.replacement_library: weight_paths = layer_replacement["weight_paths"] @@ -396,27 +170,3 @@ def _get_all_checkpoint_dirs(self) -> list[Path]: checkpoint_dir = weights_path_to_checkpoint_dir(weights_path) checkpoint_dirs.add(checkpoint_dir) return list(checkpoint_dirs) - - -def _error_message_ensure_split(checkpoint_dir: Path) -> str: - return ( - f"Encountered unsplit checkpoint dir '{checkpoint_dir}', " - f"please call `ensure_all_checkpoints_are_split`" - ) - - -def _get_owned_block_indexes(n_layer: int) -> list[int]: - last_process_blocks = np.array([n_layer - 1]) # less params in last gpu, leave room for logits - - if dist.size() == 1: - # Only one process: assign everything (including the "last process" block) to rank 0 - owned_block_indexes_per_process = [ - np.concatenate([np.arange(n_layer - 1), last_process_blocks]) - ] - else: - # Multiple processes: split n_layer-1 blocks, reserve the last for "last process" - owned_block_indexes_per_process = np.array_split(range(n_layer - 1), dist.size() - 1) - owned_block_indexes_per_process.append(last_process_blocks) - - owned_block_indexes = owned_block_indexes_per_process[dist.rank()].tolist() - return owned_block_indexes diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_utils.py b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py index 68ba0b5fc3..269e5e63ea 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_utils.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py @@ -21,8 +21,9 @@ from copy import deepcopy from pathlib import Path +from transformers import PretrainedConfig + from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.mip.utils import sort_replacements @@ -73,7 +74,7 @@ def weights_path_to_checkpoint_dir(weights_path: Path) -> Path: def replacement_is_teacher( layer_replacement: dict, - teacher_model_config: DeciLMConfig, + teacher_model_config: PretrainedConfig, teacher_checkpoint_dir: Path, ) -> bool: paths_all_teacher = all( @@ -86,7 +87,7 @@ def replacement_is_teacher( def is_replacement_identical_to_teacher( layer_replacement: dict, - teacher_model_config: DeciLMConfig, + teacher_model_config: PretrainedConfig, ) -> bool: if len(layer_replacement["parent_layer_indices"]) == 1: block_idx = layer_replacement["parent_layer_indices"][0] @@ -109,7 +110,7 @@ def is_replacement_identical_to_teacher( def split_replacements_to_teacher_and_student( replacements: list[dict], - teacher_model_config: DeciLMConfig, + teacher_model_config: PretrainedConfig, teacher_checkpoint_dir: Path, ) -> tuple[list[dict], list[dict]]: teacher_replacements, student_replacements = [], [] From 0708ca23ba5879c3107bcba16c24106b8855fdf7 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 24 Mar 2026 08:40:59 +0100 Subject: [PATCH 48/62] Dkorzekwa/anymodel subblock stats (#1085) ### What does this PR do? Integration tests for subblock stats (memory + num_of_params) ## Summary by CodeRabbit * **Refactor** * Improved teacher model configuration loading and statistics computation for enhanced accuracy in memory and parameter calculations. * **Tests** * Added comprehensive tests validating teacher model memory and parameter statistics calculations with strict accuracy tolerances. --------- Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/mip/sweep.py | 151 ++++++++---------- tests/gpu/torch/puzzletron/test_puzzletron.py | 138 ++++++++++------ 2 files changed, 163 insertions(+), 126 deletions(-) diff --git a/modelopt/torch/puzzletron/mip/sweep.py b/modelopt/torch/puzzletron/mip/sweep.py index 82046934bc..82d9b11e12 100644 --- a/modelopt/torch/puzzletron/mip/sweep.py +++ b/modelopt/torch/puzzletron/mip/sweep.py @@ -17,50 +17,44 @@ import json from pathlib import Path +from typing import Any +from omegaconf import DictConfig, OmegaConf +from transformers import PretrainedConfig + +import modelopt.torch.puzzletron.anymodel.models # noqa: F401 — register ModelDescriptorFactory entries import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import ( + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.mip.run_puzzle import _get_block_stats, filter_subblock_stats_by_args +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config from modelopt.torch.puzzletron.tools.logger import mprint -def get_teacher_memory_from_subblock_stats(hydra_cfg) -> float: - """Calculate teacher model memory from subblock_stats.json. - - Replicates the MIP solver's memory calculation logic: - - Loads subblock_stats.json which contains memory measurements for all subblock configs - - Finds the teacher FFN subblock (with full intermediate_size) - - Finds the teacher Attention subblock (full attention, not no_op) - - Calculates: non_block_memory + (ffn_memory + attention_memory) * num_layers - - This matches how the MIP solver computes total model memory via _get_block_stats(). - - Args: - hydra_cfg: Hydra configuration object - - Returns: - Total teacher memory in MiB - """ +def _load_teacher_subblock_stats(hydra_cfg: DictConfig) -> tuple[dict[str, Any], PretrainedConfig]: + """Load filtered subblock_stats and teacher ``model_config`` for the current MIP scenario.""" puzzle_dir = Path(hydra_cfg.puzzle_dir) - - # Read config.json directly from the teacher model path teacher_dir = Path(hydra_cfg.teacher_dir) - config_file = teacher_dir / "config.json" - - with open(config_file) as f: - config_dict = json.load(f) - num_layers = config_dict["num_hidden_layers"] - teacher_ffn_intermediate = config_dict["intermediate_size"] - teacher_num_kv_heads = config_dict["num_key_value_heads"] + descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor) + trust_remote_code = descriptor.requires_trust_remote_code() + model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code) + lm_config = descriptor.get_language_model_config(model_config) + hidden_size = lm_config.hidden_size - # Get the MIP configuration mip_subblock_args = hydra_cfg.mip.subblock_stats_args[0] - batch_size = mip_subblock_args["batch_size"] - weights_dtype = str(mip_subblock_args["weights_dtype"]) - activations_dtype = str(mip_subblock_args["activations_dtype"]) - kv_cache_dtype = str(mip_subblock_args["kv_cache_dtype"]) + subblock_stats_args = OmegaConf.to_container(mip_subblock_args, resolve=True) + # Subblock_stats.json can list multiple runs that share batch/dtypes but differ by hidden size; + # filter_subblock_stats_by_args needs n_embd so exactly one row matches the teacher. + subblock_stats_args = {**subblock_stats_args, "n_embd": hidden_size} + + batch_size = subblock_stats_args["batch_size"] + weights_dtype = str(subblock_stats_args["weights_dtype"]) + activations_dtype = str(subblock_stats_args["activations_dtype"]) + kv_cache_dtype = str(subblock_stats_args["kv_cache_dtype"]) - # Load subblock_stats.json subblock_stats_path = puzzle_dir / "subblock_stats.json" if not subblock_stats_path.exists(): raise FileNotFoundError( @@ -71,69 +65,64 @@ def get_teacher_memory_from_subblock_stats(hydra_cfg) -> float: with open(subblock_stats_path) as f: subblock_stats_list = json.load(f) - # Find the entry matching our MIP configuration and teacher's n_embd - matching_stats = None - for stats_entry in subblock_stats_list: - args = stats_entry["args"] - if ( - args["batch_size"] == batch_size - and args["weights_dtype"] == weights_dtype - and args["activations_dtype"] == activations_dtype - and args["kv_cache_dtype"] == kv_cache_dtype - and args.get("n_embd") == config_dict["hidden_size"] - ): - matching_stats = stats_entry - break - - if matching_stats is None: + try: + subblock_stats = filter_subblock_stats_by_args(subblock_stats_list, subblock_stats_args) + except AssertionError as e: raise ValueError( - f"No subblock_stats entry found for batch_size={batch_size}, " + f"No unique subblock_stats entry for batch_size={batch_size}, " f"dtypes=({weights_dtype}, {activations_dtype}, {kv_cache_dtype}), " - f"n_embd={config_dict['hidden_size']}" - ) + f"n_embd={hidden_size}" + ) from e - # Get non-block memory (embeddings, LM head, etc.) - total_memory = matching_stats.get("non_block", {}).get("memory_mib", 0.0) + return subblock_stats, model_config - # Find the teacher FFN and Attention subblocks - # Note: Each subblock is EITHER attention OR ffn, not both - # We need to find BOTH and add their memory together - teacher_ffn_subblock = None - teacher_attention_subblock = None - for subblock in matching_stats.get("subblocks", []): - subblock_class = subblock.get("subblock_config_class", "") - subblock_config = subblock.get("subblock_config", {}) +def get_teacher_memory_from_subblock_stats(hydra_cfg: DictConfig) -> float: + """Calculate teacher model memory from subblock_stats.json. - # Check for FFN subblocks with teacher's intermediate_size - if "FFN" in subblock_class: - ffn_size = subblock_config.get("intermediate_size") - if ffn_size == teacher_ffn_intermediate and not subblock_config.get("no_op", False): - teacher_ffn_subblock = subblock + Sums ``non_block`` and per-layer ``_get_block_stats(subblock_stats, block_config, layer_index)`` + over ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`. - # Check for Attention subblocks with teacher's num_key_value_heads - elif "Attention" in subblock_class: - kv_heads = subblock_config.get("num_key_value_heads") - if kv_heads == teacher_num_kv_heads and not subblock_config.get("no_op", False): - teacher_attention_subblock = subblock + Args: + hydra_cfg: Hydra configuration object - if teacher_ffn_subblock is None: - raise ValueError( - f"Could not find teacher FFN subblock with intermediate_size={teacher_ffn_intermediate}" - ) + Returns: + Total teacher memory in MiB + """ + subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg) - if teacher_attention_subblock is None: - raise ValueError( - f"Could not find teacher Attention subblock with num_key_value_heads={teacher_num_kv_heads}" - ) + total_memory = subblock_stats.get("non_block", {}).get("memory_mib", 0.0) - # Calculate total teacher memory: non_block + (ffn_memory + attention_memory) * num_layers - per_layer_memory = teacher_ffn_subblock["memory_mib"] + teacher_attention_subblock["memory_mib"] - total_memory += per_layer_memory * num_layers + for layer_idx, block_config in enumerate(model_config.block_configs): + block_stats = _get_block_stats(subblock_stats, block_config, layer_idx) + total_memory += block_stats["memory_mib"] return total_memory +def get_teacher_num_params_from_subblock_stats(hydra_cfg: DictConfig) -> int: + """Calculate total teacher parameter count from subblock_stats.json. + + Sums ``non_block`` and per-layer ``_get_block_stats(...)["num_params"]`` over + ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`. + + Args: + hydra_cfg: Hydra configuration object + + Returns: + Total teacher parameter count (same units as subblock_stats JSON). + """ + subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg) + + total_params = subblock_stats.get("non_block", {}).get("num_params", 0) + + for layer_idx, block_config in enumerate(model_config.block_configs): + block_stats = _get_block_stats(subblock_stats, block_config, layer_idx) + total_params += block_stats["num_params"] + + return int(total_params) + + def extract_solution_results( solution_path: Path, target_memory_mib: float, diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index fa9e5281dc..f3f49bed27 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -27,6 +27,10 @@ import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron import puzzletron from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.mip.sweep import ( + get_teacher_memory_from_subblock_stats, + get_teacher_num_params_from_subblock_stats, +) # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) # using a one-click command. @@ -106,7 +110,7 @@ def _test_puzzletron_multiprocess_job( dist.barrier() # Compress the model using a one-click approach - puzzletron.puzzletron( + hydra_cfg = puzzletron.puzzletron( str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path) ) @@ -157,7 +161,7 @@ def _test_puzzletron_multiprocess_job( # assertions for the build_library_and_stats step 4 assert (puzzle_dir / "replacement_library.json").is_file() - assert (puzzle_dir / "subblock_stats.json").is_file() + _assert_subblock_stats_anymodel(hf_model_name, hydra_cfg) # assertions for the scoring step 5 solution_0_filepath = ( @@ -173,50 +177,20 @@ def _test_puzzletron_multiprocess_job( ) -# Expected pruning activation values per model -# Each model has a list of (score, channels) tuples for each FFN layer -EXPECTED_PRUNING_VALUES = { - "meta-llama/Llama-3.1-8B-Instruct": [ - {"score": 73, "channels": 95}, - {"score": 440, "channels": 174}, - ], - "meta-llama/Llama-3.2-3B-Instruct": [ - {"score": 79, "channels": 95}, - {"score": 428, "channels": 174}, - ], - "mistralai/Mistral-Small-24B-Instruct-2501": [ - {"score": 73, "channels": 95}, - {"score": 431, "channels": 174}, - ], - # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer) - "nvidia/NVIDIA-Nemotron-Nano-12B-v2": [ - {"score": 70, "channels": 509}, - ], - # nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 uses MoE expert pruning, not FFN pruning - "Qwen/Qwen2.5-7B-Instruct": [ - {"score": 96, "channels": 433}, - {"score": 485, "channels": 105}, - ], - "Qwen/Qwen3-8B": [ - {"score": 208, "channels": 51}, - {"score": 475, "channels": 266}, - ], -} - +def _assert_subblock_stats_anymodel(hf_model_name: str, hydra_cfg) -> None: + """Minimal subblock_stats checks and teacher memory / param regression values.""" + assert (Path(hydra_cfg.puzzle_dir) / "subblock_stats.json").is_file() + teacher_mem_mib = get_teacher_memory_from_subblock_stats(hydra_cfg) + teacher_num_params = get_teacher_num_params_from_subblock_stats(hydra_cfg) -# Expected lm_loss values per model -EXPECTED_LM_LOSS = { - "meta-llama/Llama-3.1-8B-Instruct": 4.706878662109375, - "meta-llama/Llama-3.2-3B-Instruct": 4.816886901855469, - "mistralai/Mistral-Small-24B-Instruct-2501": 4.709150314331055, - # TODO: not reproducible in CI, skipping for now - # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 4.7737884521484375, - "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 4.79390811920166, - # "openai/gpt-oss-20b": 4.689250946044922, - "Qwen/Qwen2.5-7B-Instruct": 4.778186798095703, - "Qwen/Qwen3-8B": 4.733874320983887, - "Qwen/Qwen3-VL-30B-A3B-Instruct": 4.65625, -} + assert abs(teacher_mem_mib - EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]) < 1e-6, ( + f"Teacher memory mismatch for {hf_model_name}: " + f"expected {EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]}, got {teacher_mem_mib}" + ) + assert abs(teacher_num_params - EXPECTED_TEACHER_NUM_PARAMS[hf_model_name]) < 1e-6, ( + f"Teacher num_params mismatch for {hf_model_name}: " + f"expected {EXPECTED_TEACHER_NUM_PARAMS[hf_model_name]}, got {teacher_num_params}" + ) def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): @@ -291,3 +265,77 @@ def _assert_mip_solutions(puzzle_dir: Path, hf_model_name: str): # Validate lm_loss _assert_lm_loss(puzzle_dir, hf_model_name) + + +# Expected pruning activation values per model +# Each model has a list of (score, channels) tuples for each FFN layer +EXPECTED_PRUNING_VALUES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + {"score": 73, "channels": 95}, + {"score": 440, "channels": 174}, + ], + "meta-llama/Llama-3.2-3B-Instruct": [ + {"score": 79, "channels": 95}, + {"score": 428, "channels": 174}, + ], + "mistralai/Mistral-Small-24B-Instruct-2501": [ + {"score": 73, "channels": 95}, + {"score": 431, "channels": 174}, + ], + # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer) + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": [ + {"score": 70, "channels": 509}, + ], + # nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 uses MoE expert pruning, not FFN pruning + "Qwen/Qwen2.5-7B-Instruct": [ + {"score": 96, "channels": 433}, + {"score": 485, "channels": 105}, + ], + "Qwen/Qwen3-8B": [ + {"score": 208, "channels": 51}, + {"score": 475, "channels": 266}, + ], +} + + +# Expected lm_loss values per model +EXPECTED_LM_LOSS = { + "meta-llama/Llama-3.1-8B-Instruct": 4.706878662109375, + "meta-llama/Llama-3.2-3B-Instruct": 4.816886901855469, + "mistralai/Mistral-Small-24B-Instruct-2501": 4.709150314331055, + # TODO: not reproducible in CI, skipping for now + # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 4.7737884521484375, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 4.79390811920166, + "openai/gpt-oss-20b": 4.689250946044922, + "Qwen/Qwen2.5-7B-Instruct": 4.778186798095703, + "Qwen/Qwen3-8B": 4.733874320983887, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 4.65625, +} + + +# Expected teacher memory from subblock_stats (MiB) +EXPECTED_TEACHER_MEMORY_MIB = { + "meta-llama/Llama-3.1-8B-Instruct": 386.22705078125, + "meta-llama/Llama-3.2-3B-Instruct": 386.22705078125, + "mistralai/Mistral-Small-24B-Instruct-2501": 386.22705078125, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 552.47607421875, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 193.16357421875, + "openai/gpt-oss-20b": 456.75830078125, + "Qwen/Qwen2.5-7B-Instruct": 386.22705078125, + "Qwen/Qwen3-8B": 386.22705078125, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 420.74267578125, +} + + +# Expected total teacher params from subblock_stats +EXPECTED_TEACHER_NUM_PARAMS = { + "meta-llama/Llama-3.1-8B-Instruct": 1167616.0, + "meta-llama/Llama-3.2-3B-Instruct": 1167616.0, + "mistralai/Mistral-Small-24B-Instruct-2501": 1167616.0, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 188993280.0, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 610048.0, + "openai/gpt-oss-20b": 38146304.0, + "Qwen/Qwen2.5-7B-Instruct": 1167616.0, + "Qwen/Qwen3-8B": 1167616.0, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 19263744.0, +} From 3193f30566a205c02c9e2c54adefdba7287d79a7 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 24 Mar 2026 12:17:00 +0100 Subject: [PATCH 49/62] Dkorzekwa/anymodel subblock stats nodecilm (#1102) ### What does this PR do? Refactoring of subblock stats to stop using DeciLM code and use anymodel instead. ## Summary by CodeRabbit * **Refactor** * Restructured internal estimation logic for model memory and parameters to support broader model architectures. * Updated model initialization utilities. * **Tests** * Updated baseline metrics for validation tests. --------- Signed-off-by: Daniel Korzekwa --- .../decilm/deci_lm_hf_code/modeling_decilm.py | 105 +------- .../calc_subblock_params_and_memory.py | 232 +++++++++--------- .../subblock_stats/calc_subblock_stats.py | 17 +- .../puzzletron/tools/checkpoint_utils_hf.py | 31 ++- tests/gpu/torch/puzzletron/test_puzzletron.py | 38 +-- 5 files changed, 185 insertions(+), 238 deletions(-) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py index 0102fc3a95..915c111be5 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py @@ -18,14 +18,14 @@ # Pared-down DeciLM building blocks for Model-Optimizer puzzletron / AnyModel flows. # The full HF DeciLM decoder stack (decoder layers, attention, rope, etc.) is not vendored here; # AnyModel loads real models via transformers. This module keeps shared helpers: RMSNorm, -# gated/vanilla MLP (used by MoE accounting), MoE, and LMHead for replacement / validation code. +# gated MLP, and LMHead for replacement / validation code. # mypy: ignore-errors import torch import torch.nn.functional as F from torch import nn -from .block_config import FFNConfig, MoEConfig +from .block_config import FFNConfig from .configuration_decilm import DeciLMConfig from .transformers_4_44_2__activations import ACT2FN @@ -102,107 +102,6 @@ def forward(self, x): return down_proj -class DeciLMMoe(nn.Module): - """ - Implementation of Mixture of Experts module for DeciLM. - Equivalent to Llama4 MoE but implemented more frugally. - """ - - def __init__(self, config: DeciLMConfig, ffn_config: FFNConfig): - super().__init__() - self.config = config - self.ffn_config = ffn_config - - # MoE parameters - assert ffn_config.moe is not None, "MoE configuration must be provided to use DeciLMMoe" - self.moe_config: MoEConfig = ffn_config.moe - self.hidden_dim = config.hidden_size - self.num_experts_per_tok = self.moe_config.num_experts_per_tok - self.num_local_experts = self.moe_config.num_local_experts - self.expert_intermediate_dim = self.moe_config.expert_intermediate_dim - self.shared_expert_intermediate_dim = self.moe_config.shared_expert_intermediate_dim - - # Initialize experts and router - routed_expert_ffn_config = FFNConfig( - intermediate_size=self.expert_intermediate_dim, - ) - - self.experts = nn.ModuleList( - [ - DeciLMGatedMLP(config, routed_expert_ffn_config) - for _ in range(self.num_local_experts) - ] - ) - - self.router = nn.Linear(config.hidden_size, self.num_local_experts, bias=False) - - # Initialize shared expert as a standard MLP - shared_expert_ffn_config = FFNConfig( - intermediate_size=self.moe_config.shared_expert_intermediate_dim - ) - self.shared_expert = DeciLMGatedMLP(config, shared_expert_ffn_config) - - if ffn_config.sparsify is not None: - self.register_full_backward_hook(sparsity_backward_hook) - - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass through the MoE layer. - - Args: - hidden_states (torch.Tensor): Input tensor of shape (batch, seq_len, hidden_dim) - - Returns: - tuple: - - torch.Tensor: Output tensor of shape (batch, seq_len, hidden_dim) - - torch.Tensor: Router scores for loss computation - """ - router_logits = self.router(hidden_states) - - routed_out = self.forward_routed_experts(hidden_states, router_logits) - - shared_out = self.shared_expert(hidden_states) - - moe_out = routed_out + shared_out - - return moe_out, router_logits - - def forward_routed_experts( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor - ) -> torch.Tensor: - """ - For each expert: - 1. Build the input to the expert based on the router mask - 2. Run the expert - 3. Add the result of the expert into the total MoE result using += - """ - router_top_values, router_indices = torch.topk( - router_logits, self.num_experts_per_tok, dim=-1 - ) - router_scores = torch.sigmoid(router_top_values.float()).to(hidden_states.dtype) - - routed_out = torch.zeros_like(hidden_states) - for i_expert in range(self.num_local_experts): - expert_mask = router_indices == i_expert - if expert_mask.any(): - is_token_routed_to_this_expert = expert_mask.any(dim=-1) - relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] - relevant_scores = router_scores[expert_mask] - expert_in = relevant_hidden_states * relevant_scores.unsqueeze(-1) - - expert_out = self.experts[i_expert](expert_in).to(hidden_states.device) - - routed_out[is_token_routed_to_this_expert, :] += expert_out - - return routed_out - - def extra_repr(self) -> str: - return ( - f"(MoE): num_local_experts={self.num_local_experts}, " - f"expert_intermediate_dim={self.expert_intermediate_dim}," - ) - - class LMHead(nn.Linear): """ Special class to allow FSDP wrapping without affecting other Linear layers in the model. diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py index 88081d1773..a93e40978f 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py @@ -21,21 +21,27 @@ considering various data types, batch sizes, and sequence lengths. """ +import copy import json import math from pathlib import Path +from typing import Type import numpy as np import torch +from transformers import PretrainedConfig +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, + BlockConfig, FFNConfig, MambaConfig, + maybe_cast_block_configs, ) -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMMoe +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import init_model_from_config from modelopt.torch.puzzletron.utils.utils import ( + EmptyInitOnDevice, calculate_kv_dim, raise_unknown_subblock_config_error, sizeof_dtype, @@ -53,21 +59,34 @@ def calculate_subblock_memory( weights_dtype: torch.dtype, kv_cache_dtype: torch.dtype, allocate_prefill_query: bool, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], ) -> float | dict[str, float]: + """``model_config`` / ``descriptor`` are required (puzzletron-style); FFN uses them for meta init.""" if subblock_config.no_op: return 0 - if subblock_config.replace_with_linear: - return calculate_linear_memory(n_embd, weights_dtype) if isinstance(subblock_config, FFNConfig): - return calculate_ffn_memory(subblock_config, n_embd, weights_dtype) + return calculate_ffn_memory( + subblock_config, + model_config, + descriptor, + weights_dtype, + ) if isinstance(subblock_config, AttentionConfig): if subblock_config.is_mamba: return calculate_mamba_memory( - subblock_config.mamba, n_embd, batch_size, weights_dtype, kv_cache_dtype + subblock_config, + model_config, + descriptor, + batch_size, + weights_dtype, + kv_cache_dtype, ) else: return calculate_attention_memory( subblock_config, + model_config, + descriptor, batch_size, prefill_seq_len, generation_seq_len, @@ -82,38 +101,94 @@ def calculate_subblock_memory( def calculate_subblock_params( - subblock_config: FFNConfig | AttentionConfig, - n_embd: int, - n_head: int, + config: PretrainedConfig, + layer_config: BlockConfig | FFNConfig | AttentionConfig, + descriptor: Type[ModelDescriptor], ) -> int: - if subblock_config.no_op: + """Count parameters on one meta decoder layer (puzzletron ``calculate_subblock_params`` parity). + + Unlike ``puzzletron_patcher`` during ``__init__``, we do **not** use ``deci_x_patcher`` here: + for models such as GPT-OSS, Transformers ``post_init`` validates ``_keep_in_fp32_modules`` + against the module tree; replacing norms / attn / mlp with no-op placeholders **before** + ``post_init`` raises (e.g. ``post_attention_layernorm`` … not part of the modules). + + With ``num_hidden_layers == 1`` we merge ``block_config_to_layer_overrides`` into the LM config + (what the patcher would pass into ``DecoderLayer.__init__``), build a stock layer, run + ``post_init``, then apply ``attn_no_op_post_init`` / ``mlp_no_op_post_init`` for param counting. + """ + if isinstance(layer_config, FFNConfig): + block_config = layer_config.to_blockconfig() + elif isinstance(layer_config, AttentionConfig): + block_config = layer_config.to_blockconfig() + else: + block_config = layer_config + + ffn = block_config.ffn + attn = block_config.attention + ffn_no_op = ffn is None or ffn.no_op + attn_no_op = attn is None or attn.no_op + if not (ffn_no_op or attn_no_op): + raise AssertionError( + "One of ffn or attention must be no-op for sublayer param calculation " + "(single subblock at a time)." + ) + if ffn_no_op and attn_no_op: return 0 - if subblock_config.replace_with_linear: - return calculate_linear_params(n_embd) - if isinstance(subblock_config, FFNConfig): - return calculate_ffn_params(subblock_config, n_embd) - if isinstance(subblock_config, AttentionConfig): - if subblock_config.is_mamba: - return calculate_mamba_params(subblock_config.mamba, n_embd) - else: - return calculate_attention_params(subblock_config, n_embd, n_head) - raise_unknown_subblock_config_error(subblock_config) + + _config = copy.deepcopy(config) + lm_config = descriptor.get_language_model_config(_config) + lm_config.num_hidden_layers = 1 + + block_configs = maybe_cast_block_configs([block_config]) + _config.block_configs = block_configs + if lm_config is not _config: + lm_config.block_configs = block_configs + + # Replaced earlier pattern: + # with EmptyInitOnDevice("meta"), deci_x_patcher(..., block_configs=block_configs): + # model = init_model_from_config(_config, ...) + # That fails on GPT-OSS with recent Transformers: ``deci_x_patcher`` runs + # ``attn_no_op_post_init`` / ``mlp_no_op_post_init`` inside ``DecoderLayer.__init__``, so norms + # / attn / mlp are swapped for placeholders before ``GptOssModel.post_init`` runs; ``post_init`` + # then raises ``ValueError`` (e.g. ``post_attention_layernorm`` in ``_keep_in_fp32_modules`` no + # longer matches the tree). Below we merge per-layer fields manually, init without the patcher, + # then call the same descriptor no-op hooks on the built layer (equivalent param count for + # ``num_hidden_layers == 1``). + + # ``block_config_to_layer_overrides`` may include keys with value ``None``; we omit those so + # ``lm_config.update`` does not overwrite existing fields with ``None`` (same rule as + # ``override_config_with_block_configs`` inside ``deci_x_patcher``). + layer_overrides = descriptor.block_config_to_layer_overrides(block_configs[0]) + lm_config.update({k: v for k, v in layer_overrides.items() if v is not None}) + + with EmptyInitOnDevice("meta"): + model = init_model_from_config( + _config, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + decoder_layer = model.get_submodule(descriptor.layer_block_name(index=0)) + if attn_no_op: + descriptor.attn_no_op_post_init(decoder_layer) + if ffn_no_op: + descriptor.mlp_no_op_post_init(decoder_layer) + return sum(p.numel() for p in decoder_layer.parameters()) def calc_subblock_active_params( - subblock_config: FFNConfig | AttentionConfig, + sublayer_config: FFNConfig | AttentionConfig, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], n_embd: int, - n_head: int, moe_stats_file: str, batch_size: int, block_idx: int, ) -> int: - if not (isinstance(subblock_config, FFNConfig) and subblock_config.is_moe): - return calculate_subblock_params(subblock_config, n_embd, n_head) - else: - return estimate_moe_active_params( - subblock_config, n_embd, moe_stats_file, batch_size, block_idx - ) + if not (isinstance(sublayer_config, FFNConfig) and sublayer_config.is_moe): + return calculate_subblock_params(model_config, sublayer_config, descriptor) + return estimate_moe_active_params( + sublayer_config, n_embd, moe_stats_file, batch_size, block_idx + ) def load_moe_stats(stats_file: str) -> dict: @@ -168,6 +243,8 @@ def estimate_moe_active_params( def calculate_attention_memory( attention_config: AttentionConfig, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], batch_size: int, prefill_seq_len: int, generation_seq_len: int, @@ -193,7 +270,7 @@ def calculate_attention_memory( total_num_tokens = seq_len * (batch_size + prefill_queue_size) kv_cache_size = total_num_tokens * kv_dim query_prefill_size = seq_len * n_embd if allocate_prefill_query else 0 - num_params = calculate_attention_params(attention_config, n_embd, n_head) + num_params = calculate_subblock_params(model_config, attention_config, descriptor) total_memory = ( kv_cache_size * sizeof_dtype(kv_cache_dtype) + query_prefill_size * sizeof_dtype(weights_dtype) @@ -203,52 +280,23 @@ def calculate_attention_memory( return {"memory_mib": total_memory, "kv_cache_memory_mib": kv_cache_memory} -def calculate_attention_params( - attention_config: AttentionConfig, - n_embd: int, - n_head: int, -) -> int: - kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd) - return ( - n_embd * n_embd * 2 # Wq + Wo - + n_embd * kv_dim # Wk + Wv - + n_embd # rms norm - ) - - def calculate_mamba_memory( - mamba_config: MambaConfig, - n_embd: int, + attention_config: AttentionConfig, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], batch_size: int, weights_dtype: torch.dtype, kv_cache_dtype: torch.dtype, ) -> int: + assert attention_config.mamba is not None + mamba_config = attention_config.mamba + num_params = calculate_subblock_params(model_config, attention_config, descriptor) return ( - calculate_mamba_params(mamba_config, n_embd) * sizeof_dtype(weights_dtype) + num_params * sizeof_dtype(weights_dtype) + calculate_mamba_state_size(mamba_config, batch_size) * sizeof_dtype(kv_cache_dtype) ) / 2**20 -def calculate_mamba_params( - mamba_config: MambaConfig, - n_embd: int, -) -> int: - d_inner, in_proj_dim, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config) - param_shapes = { - "A_log": (mamba_config.num_heads,), - "D": (mamba_config.num_heads,), - "conv1d.bias": (conv_dim,), - "conv1d.weight": (conv_dim, 1, kernel_size), - "dt_bias": (mamba_config.num_heads,), - "in_proj.weight": (in_proj_dim, n_embd), - "norm.weight": (d_inner,), - "out_proj.weight": (n_embd, d_inner), - } - mamba_mixer_params = sum([math.prod(shape) for shape in param_shapes.values()]) - rms_norm_params = n_embd - return mamba_mixer_params + rms_norm_params - - def calculate_mamba_state_size( mamba_config: MambaConfig, batch_size: int, @@ -271,60 +319,18 @@ def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...] return d_inner, in_proj_dim, conv_dim, kernel_size -def calculate_linear_memory( - n_embd: int, - weights_dtype: torch.dtype, -) -> float: - return calculate_linear_params(n_embd) * sizeof_dtype(weights_dtype) / 2**20 - - -def calculate_linear_params( - n_embd: int, -) -> int: - return n_embd**2 + n_embd - - def calculate_ffn_memory( ffn_config: FFNConfig, - n_embd: int, - weights_dtype: torch.dtype, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], + weights_dtype: torch.dtype | str, + experts_dtype: torch.dtype | str | None = None, ) -> float: - num_params = calculate_ffn_params(ffn_config, n_embd) + # TODO: How to separate between expert weights and the rest for any model (same as puzzletron). + num_params = calculate_subblock_params(model_config, ffn_config, descriptor) return num_params * sizeof_dtype(weights_dtype) / 2**20 -def calculate_ffn_params( - ffn_config: FFNConfig, - n_embd: int, -) -> float: - if ffn_config.is_moe: - return calculate_moe_params(ffn_config, n_embd) - else: - return calculate_dense_ffn_params(ffn_config, n_embd) - - -def calculate_dense_ffn_params( - ffn_config: FFNConfig, - n_embd: int, -) -> int: - intermediate_size = ffn_config.intermediate_size - num_linear_layers = 3 if getattr(ffn_config, "gated", True) else 2 - rms_norm_params = n_embd - return n_embd * intermediate_size * num_linear_layers + rms_norm_params - - -def calculate_moe_params( - ffn_config: FFNConfig, - n_embd: int, -) -> int: - with torch.device("meta"): - config = DeciLMConfig(hidden_size=n_embd) - moe = DeciLMMoe(config, ffn_config) - moe_params = sum(p.numel() for p in moe.parameters()) - layernorm_params = n_embd - return moe_params + layernorm_params - - def calculate_non_block_memory( n_embd: int, vocab_size: int, diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index 0b8a3e72fe..1dcc6c1b22 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -31,6 +31,7 @@ from immutabledict import immutabledict from omegaconf import DictConfig, ListConfig, OmegaConf from tqdm import tqdm +from transformers import PretrainedConfig from modelopt.torch.puzzletron.anymodel.model_descriptor import ( ModelDescriptor, @@ -72,6 +73,8 @@ def calculate_subblock_stats( calc_subblock_stats_config: DictConfig, teacher_dir: Path, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], master_puzzle_dir: Path, subblock_configs: list[immutabledict[str, AttentionConfig | FFNConfig]], batch_size: int, @@ -167,14 +170,22 @@ def calculate_subblock_stats( weights_dtype, kv_cache_dtype, allocate_prefill_query, + model_config=model_config, + descriptor=descriptor, ) if not isinstance(subblock_memory, dict): subblock_memory = {"memory_mib": subblock_memory, "kv_cache_memory_mib": 0.0} - subblock_params = calculate_subblock_params(subblock_config, n_embd, n_head) + subblock_params = calculate_subblock_params(model_config, subblock_config, descriptor) if moe_stats_file is not None: subblock_active_params = calc_subblock_active_params( - subblock_config, n_embd, n_head, moe_stats_file, batch_size, parent_layer_indices[0] + subblock_config, + model_config, + descriptor, + n_embd, + moe_stats_file, + batch_size, + parent_layer_indices[0], ) else: subblock_active_params = subblock_params @@ -337,6 +348,8 @@ def calculate_subblock_stats_for_puzzle_dir( curr_subblock_stats = calculate_subblock_stats( calc_subblock_stats_config, teacher_dir=teacher_dir, + model_config=model_config, + descriptor=descriptor, master_puzzle_dir=master_puzzle_dir, subblock_configs=subblock_configs, batch_size=batch_size, diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 020afdfadd..54e2bdafd5 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -31,8 +31,9 @@ from typing import Any, BinaryIO import torch +import transformers from safetensors.torch import save_file as safe_save_file -from transformers import AutoConfig, PretrainedConfig, PreTrainedModel +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.utils import SAFE_WEIGHTS_INDEX_NAME @@ -121,6 +122,34 @@ def load_model_config( return config +def _get_model_class_from_config(config: PretrainedConfig) -> type: + """Resolve HuggingFace model class from ``config.architectures`` (see puzzletron checkpoint_utils_hf).""" + if hasattr(config, "architectures") and config.architectures: + model_class_name = config.architectures[0] + if hasattr(transformers, model_class_name): + return getattr(transformers, model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, " + "falling back to AutoModelForCausalLM" + ) + return AutoModelForCausalLM + + +def init_model_from_config( + config: PretrainedConfig, + *, + trust_remote_code: bool = True, + **kwargs, +) -> PreTrainedModel: + """Build a model from config on meta/uninitialized weights (used e.g. for subblock param counts).""" + model_class = _get_model_class_from_config(config) + if model_class is AutoModelForCausalLM: + return model_class.from_config(config, trust_remote_code=trust_remote_code, **kwargs) + # Concrete model classes (e.g. GptOssForCausalLM): _from_config forwards kwargs to __init__, + # which does not accept trust_remote_code (only AutoModel uses it when loading custom code). + return model_class._from_config(config, **kwargs) + + def save_checkpoint( model: PreTrainedModel, checkpoint_dir: Path | str, diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index f3f49bed27..88635a6aaf 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -304,7 +304,7 @@ def _assert_mip_solutions(puzzle_dir: Path, hf_model_name: str): "meta-llama/Llama-3.2-3B-Instruct": 4.816886901855469, "mistralai/Mistral-Small-24B-Instruct-2501": 4.709150314331055, # TODO: not reproducible in CI, skipping for now - # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 4.7737884521484375, + # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 4.733944892883301, "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 4.79390811920166, "openai/gpt-oss-20b": 4.689250946044922, "Qwen/Qwen2.5-7B-Instruct": 4.778186798095703, @@ -315,27 +315,27 @@ def _assert_mip_solutions(puzzle_dir: Path, hf_model_name: str): # Expected teacher memory from subblock_stats (MiB) EXPECTED_TEACHER_MEMORY_MIB = { - "meta-llama/Llama-3.1-8B-Instruct": 386.22705078125, - "meta-llama/Llama-3.2-3B-Instruct": 386.22705078125, - "mistralai/Mistral-Small-24B-Instruct-2501": 386.22705078125, - "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 552.47607421875, - "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 193.16357421875, - "openai/gpt-oss-20b": 456.75830078125, - "Qwen/Qwen2.5-7B-Instruct": 386.22705078125, - "Qwen/Qwen3-8B": 386.22705078125, - "Qwen/Qwen3-VL-30B-A3B-Instruct": 420.74267578125, + "meta-llama/Llama-3.1-8B-Instruct": 395.60205078125, + "meta-llama/Llama-3.2-3B-Instruct": 395.60205078125, + "mistralai/Mistral-Small-24B-Instruct-2501": 395.60205078125, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 202.10107421875, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 202.10107421875, + "openai/gpt-oss-20b": 437.302490234375, + "Qwen/Qwen2.5-7B-Instruct": 386.228515625, + "Qwen/Qwen3-8B": 395.60302734375, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 406.11865234375, } # Expected total teacher params from subblock_stats EXPECTED_TEACHER_NUM_PARAMS = { - "meta-llama/Llama-3.1-8B-Instruct": 1167616.0, - "meta-llama/Llama-3.2-3B-Instruct": 1167616.0, - "mistralai/Mistral-Small-24B-Instruct-2501": 1167616.0, - "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 188993280.0, - "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 610048.0, - "openai/gpt-oss-20b": 38146304.0, - "Qwen/Qwen2.5-7B-Instruct": 1167616.0, - "Qwen/Qwen3-8B": 1167616.0, - "Qwen/Qwen3-VL-30B-A3B-Instruct": 19263744.0, + "meta-llama/Llama-3.1-8B-Instruct": 6082816.0, + "meta-llama/Llama-3.2-3B-Instruct": 6082816.0, + "mistralai/Mistral-Small-24B-Instruct-2501": 6082816.0, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 5295872.0, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 5295872.0, + "openai/gpt-oss-20b": 27945856.0, + "Qwen/Qwen2.5-7B-Instruct": 1168384.0, + "Qwen/Qwen3-8B": 6083328.0, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 11596544.0, } From 928036edae2184fc2f6a60de2ddecfc38e2ed37e Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 24 Mar 2026 14:10:48 +0100 Subject: [PATCH 50/62] Dkorzekwa/decilm cleanup post subblockstats (#1103) ### What does this PR do? Removing unused code from modelopt/torch/puzzletron/decilm/deci_lm_hf_code - completed. ## Summary by CodeRabbit * **Chores** * Removed vendored Transformers modules including activation functions, cache utilities, and configuration classes. * Removed custom DeciLM configuration class. * Updated internal type signatures from DeciLMConfig to standard Hugging Face PretrainedConfig. * Simplified internal function signatures by removing unused parameters. * Enhanced code quality checks via pre-commit configuration updates. --------- Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 7 +- .../deci_lm_hf_code/configuration_decilm.py | 204 --- .../decilm/deci_lm_hf_code/modeling_decilm.py | 85 +- .../transformers_4_44_2__activations.py | 254 --- .../transformers_4_44_2__cache_utils.py | 1447 ----------------- ...ransformers_4_44_2__configuration_llama.py | 219 --- ...ransformers_4_44_2__modeling_rope_utils.py | 574 ------- ...ansformers_4_51_3__configuration_llama4.py | 447 ----- .../build_replacement_library.py | 2 +- .../replacement_library.py | 3 +- .../subblock_stats/calc_subblock_stats.py | 5 +- .../tools/bypassed_training/child_init.py | 24 +- .../puzzletron/tools/checkpoint_utils.py | 10 +- .../puzzletron/tools/checkpoint_utils_hf.py | 4 +- 14 files changed, 27 insertions(+), 3258 deletions(-) delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py delete mode 100644 modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b278013bb8..3f570a9200 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,15 +29,13 @@ repos: exclude: > (?x)^( ^examples/specdec_bench/specdec_bench/datasets/speed\.py$| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py )$ - id: ruff-format exclude: > (?x)^( ^examples/specdec_bench/specdec_bench/datasets/speed\.py$| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py + modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py )$ - repo: https://github.com/pre-commit/mirrors-mypy @@ -95,7 +93,6 @@ repos: modelopt/torch/speculative/eagle/utils.py| modelopt/torch/speculative/plugins/transformers.py| modelopt/torch/utils/plugins/megatron_mmlu.py| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py| examples/chained_optimizations/bert_prune_distill_quantize.py| examples/deepseek/quantize_to_nvfp4.py| examples/deepseek/ptq.py| diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py deleted file mode 100644 index 34a7e8cfcf..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py +++ /dev/null @@ -1,204 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# mypy: ignore-errors - -import copy -import dataclasses -import warnings -from typing import Any - -from transformers.utils import is_flash_attn_2_available # , is_torch_sdpa_available - -from .block_config import BlockConfig -from .transformers_4_44_2__configuration_llama import LlamaConfig - -# fakes imports to make AutoConfig infer dependencies -from .transformers_4_44_2__modeling_rope_utils import rope_config_validation -from .transformers_4_51_3__configuration_llama4 import Llama4Config - -# make sure that auto-formatting doesn't remove the fake imports -rope_config_validation -Llama4Config - - -class DeciLMConfig(LlamaConfig): - model_type = "nemotron-nas" - - # Mapping from global attribute names to their per-layer equivalents in block_configs - # Format: 'global_name': ('block_section', 'layer_name') - PER_LAYER_ATTRIBUTE_MAPPING = { - "intermediate_size": ("ffn", "intermediate_size"), - "num_key_value_heads": ( - "attention", - "n_heads_in_group", - ), # Note: derived value (num_heads / num_kv_heads) - "hidden_act": ("ffn", "hidden_act"), - "sliding_window": ("attention", "window_length"), # Note: different name! - } - - def __init__( - self, - block_configs: list[dict] | list[BlockConfig] | None = None, - position_embedding_type: str = "rope", - llama4_attn_implementation: str | None = None, - block_return_only_hidden_states: bool = False, - router_aux_loss_coef: float = 0.01, - router_z_loss_coef: float = 0.0, - output_router_logits: bool = False, - head_dim: int | None = 128, - o_proj_bias: bool = False, - **kwargs, - ): - self.block_configs: list[BlockConfig] = block_configs - if self.block_configs is not None: - if isinstance(self.block_configs[0], dict): - self.block_configs = [BlockConfig(**conf) for conf in self.block_configs] - - assert position_embedding_type in ["rope", "rope_llama4", "none", "mistral_yarn"] - self.position_embedding_type = position_embedding_type - if self.position_embedding_type == "none": - self.rope_theta = None - self.rope_scaling = None - - self.block_return_only_hidden_states = block_return_only_hidden_states - self.router_aux_loss_coef = router_aux_loss_coef - self.router_z_loss_coef = router_z_loss_coef - self.output_router_logits = output_router_logits - self.o_proj_bias = o_proj_bias - - self._choose_llama4_attn_implementation(llama4_attn_implementation) - attn_implementation = self._choose_llama3_attn_implementation(kwargs) - super().__init__(attn_implementation=attn_implementation, **kwargs) - self.head_dim = ( - head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - ) - - # Delete per-layer attributes after parent init (they should only exist in block_configs) - self._delete_per_layer_attributes() - - if self.block_configs is not None: - assert len(self.block_configs) == self.num_hidden_layers - - def _delete_per_layer_attributes(self): - """Delete per-layer attributes that should only exist in block_configs. - - These attributes are intentionally deleted AFTER super().__init__() to ensure - they don't exist at the global config level. Deleting them (rather than setting - to None) makes it clear they shouldn't be accessed globally. - """ - present_attrs = { - attr: getattr(self, attr) - for attr in self.PER_LAYER_ATTRIBUTE_MAPPING - if hasattr(self, attr) - } - if present_attrs: - warnings.warn( - f"Deleting global per-layer attributes (should only be in block_configs): {present_attrs}", - UserWarning, - stacklevel=3, - ) - for attr in self.PER_LAYER_ATTRIBUTE_MAPPING: - if hasattr(self, attr): - delattr(self, attr) - - def _choose_llama4_attn_implementation(self, llama4_attn_implementation): - self.llama4_attn_implementation = llama4_attn_implementation - if self.llama4_attn_implementation is None: - _print_once("auto-setting llama4_attn_implementation to sdpa") - self.llama4_attn_implementation = "sdpa" - - def _choose_llama3_attn_implementation(self, kwargs: dict[str, Any]) -> str: - attn_implementation = kwargs.pop("attn_implementation", None) - if attn_implementation is None and is_flash_attn_2_available(): - _print_once("auto-setting attn_implementation (for Llama3 layers) to flash_attention_2") - attn_implementation = "flash_attention_2" - - if self.block_configs is not None: - using_unshifted_sink = any( - block_config.attention.unshifted_sink for block_config in self.block_configs - ) - if using_unshifted_sink and attn_implementation != "eager": - warnings.warn( - "Forcing attn_implementation='eager' since some attention layers use unshifted sink" - ) - attn_implementation = "eager" - return attn_implementation - - def to_dict(self) -> dict[str, Any]: - """Convert config to dictionary, removing per-layer-only attributes.""" - self_dict = super().to_dict() - if self.block_configs is not None: - self_dict["block_configs"] = [dataclasses.asdict(conf) for conf in self.block_configs] - - # Remove global keys that should only exist per-layer in block_configs - for key in self.PER_LAYER_ATTRIBUTE_MAPPING: - self_dict.pop(key, None) - - return self_dict - - def set_block_configs(self, block_configs: list[BlockConfig]) -> "DeciLMConfig": - new_model_config = copy.deepcopy(self) - new_model_config.block_configs = block_configs - new_model_config.num_hidden_layers = len(block_configs) - return new_model_config - - def get_num_hidden_layers(self) -> int: - return self.num_hidden_layers - - def get_hidden_size(self) -> int: - return self.hidden_size - - def get_embedding_layer_name(self) -> str: - return "model.embed_tokens" - - def get_final_layer_norm_layer_name(self) -> str: - return "model.norm" - - def get_lm_head_layer_name(self) -> str: - return "lm_head" - - def get_layers_layer_name(self) -> str: - return "model.layers" - - def get_block_config(self, layer_idx: int | tuple[int, ...]) -> BlockConfig: - if isinstance(layer_idx, tuple) and len(layer_idx) == 1: - layer_idx = layer_idx[0] - - if isinstance(layer_idx, int): - return self.block_configs[layer_idx] - - external_layer_idx, internal_layer_idx = layer_idx - return self.block_configs[external_layer_idx].parallel_blocks[internal_layer_idx] - - def get_min_attention_chunk_size(self) -> int | None: - min_chunk_size = float("inf") - for block_config in self.block_configs: - if block_config.attention.llama4 is not None: - attention_chunk_size = block_config.attention.llama4.attention_chunk_size - if attention_chunk_size is not None: - min_chunk_size = min(min_chunk_size, attention_chunk_size) - - if min_chunk_size == float("inf"): - return None - return min_chunk_size - - -def _print_once(message: str): - if not hasattr(_print_once, "was_printed"): - _print_once.was_printed = set() - if message not in _print_once.was_printed: - _print_once.was_printed.add(message) - print(message) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py index 915c111be5..0a0f8ab1ef 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py @@ -15,92 +15,13 @@ # Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved. # -# Pared-down DeciLM building blocks for Model-Optimizer puzzletron / AnyModel flows. -# The full HF DeciLM decoder stack (decoder layers, attention, rope, etc.) is not vendored here; -# AnyModel loads real models via transformers. This module keeps shared helpers: RMSNorm, -# gated MLP, and LMHead for replacement / validation code. +# Small nn helpers for puzzletron pipeline code. Model configs come from HuggingFace ``AutoConfig`` (AnyModel). +# ``LMHead`` is a distinct ``nn.Linear`` subclass so pipeline / FSDP code can target it explicitly +# (see ``validate_runtime_pipeline``). # mypy: ignore-errors -import torch -import torch.nn.functional as F from torch import nn -from .block_config import FFNConfig -from .configuration_decilm import DeciLMConfig -from .transformers_4_44_2__activations import ACT2FN - - -class DeciLMRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - DeciLMRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -def sparsity_backward_hook(*args, **kwargs): - raise NotImplementedError( - "No support for sparsity when training HF DeciLM (inference is ok though)" - ) - - -class DeciLMGatedMLP(nn.Module): - def __init__( - self, - config: DeciLMConfig, - ffn_config: FFNConfig, - ): - super().__init__() - self.config = config - self.ffn_config = ffn_config - self.hidden_size = config.hidden_size - self.intermediate_size = ffn_config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[getattr(ffn_config, "hidden_act", "silu")] - - if ffn_config.sparsify is not None: - self.register_full_backward_hook(sparsity_backward_hook) - - def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], - dim=-1, - ) - up_proj = torch.cat( - [F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - class LMHead(nn.Linear): """ diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py deleted file mode 100644 index 6c964dbfc1..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from collections import OrderedDict - -import torch -from packaging import version -from torch import Tensor, nn - -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - - -class PytorchGELUTanh(nn.Module): - """ - A fast C implementation of the tanh approximation of the GeLU activation function. See - https://arxiv.org/abs/1606.08415. - - This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical - match due to rounding errors. - """ - - def __init__(self): - super().__init__() - if version.parse(torch.__version__) < version.parse("1.12.0"): - raise ImportError( - f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " - "PytorchGELUTanh. Please upgrade torch." - ) - - def forward(self, input: Tensor) -> Tensor: - return nn.functional.gelu(input, approximate="tanh") - - -class NewGELUActivation(nn.Module): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see - the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 - """ - - def forward(self, input: Tensor) -> Tensor: - return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) - - -class GELUActivation(nn.Module): - """ - Original Implementation of the GELU activation function in Google BERT repo when initially created. For - information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + - torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional - Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 - """ - - def __init__(self, use_gelu_python: bool = False): - super().__init__() - if use_gelu_python: - self.act = self._gelu_python - else: - self.act = nn.functional.gelu - - def _gelu_python(self, input: Tensor) -> Tensor: - return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) - - def forward(self, input: Tensor) -> Tensor: - return self.act(input) - - -class FastGELUActivation(nn.Module): - """ - Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs - """ - - def forward(self, input: Tensor) -> Tensor: - return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) - - -class QuickGELUActivation(nn.Module): - """ - Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs - """ - - def forward(self, input: Tensor) -> Tensor: - return input * torch.sigmoid(1.702 * input) - - -class ClippedGELUActivation(nn.Module): - """ - Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as - it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to - https://arxiv.org/abs/2004.09602. - - Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when - initially created. - - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + - torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 - """ - - def __init__(self, min: float, max: float): - if min > max: - raise ValueError(f"min should be < max (got min: {min}, max: {max})") - - super().__init__() - self.min = min - self.max = max - - def forward(self, x: Tensor) -> Tensor: - return torch.clip(gelu(x), self.min, self.max) - - -class AccurateGELUActivation(nn.Module): - """ - Applies GELU approximation that is faster than default and more accurate than QuickGELU. See: - https://github.com/hendrycks/GELUs - - Implemented along with MEGA (Moving Average Equipped Gated Attention) - """ - - def __init__(self): - super().__init__() - self.precomputed_constant = math.sqrt(2 / math.pi) - - def forward(self, input: Tensor) -> Tensor: - return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3)))) - - -class MishActivation(nn.Module): - """ - See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also - visit the official repository for the paper: https://github.com/digantamisra98/Mish - """ - - def __init__(self): - super().__init__() - if version.parse(torch.__version__) < version.parse("1.9.0"): - self.act = self._mish_python - else: - self.act = nn.functional.mish - - def _mish_python(self, input: Tensor) -> Tensor: - return input * torch.tanh(nn.functional.softplus(input)) - - def forward(self, input: Tensor) -> Tensor: - return self.act(input) - - -class LinearActivation(nn.Module): - """ - Applies the linear activation function, i.e. forwarding input directly to output. - """ - - def forward(self, input: Tensor) -> Tensor: - return input - - -class LaplaceActivation(nn.Module): - """ - Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See - https://arxiv.org/abs/2209.10655 - - Inspired by squared relu, but with bounded range and gradient for better stability - """ - - def forward(self, input, mu=0.707107, sigma=0.282095): - input = (input - mu).div(sigma * math.sqrt(2.0)) - return 0.5 * (1.0 + torch.erf(input)) - - -class ReLUSquaredActivation(nn.Module): - """ - Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 - """ - - def forward(self, input): - relu_applied = nn.functional.relu(input) - squared = torch.square(relu_applied) - return squared - - -class ClassInstantier(OrderedDict): - def __getitem__(self, key): - content = super().__getitem__(key) - cls, kwargs = content if isinstance(content, tuple) else (content, {}) - return cls(**kwargs) - - -ACT2CLS = { - "gelu": GELUActivation, - "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), - "gelu_fast": FastGELUActivation, - "gelu_new": NewGELUActivation, - "gelu_python": (GELUActivation, {"use_gelu_python": True}), - "gelu_pytorch_tanh": PytorchGELUTanh, - "gelu_accurate": AccurateGELUActivation, - "laplace": LaplaceActivation, - "leaky_relu": nn.LeakyReLU, - "linear": LinearActivation, - "mish": MishActivation, - "quick_gelu": QuickGELUActivation, - "relu": nn.ReLU, - "relu2": ReLUSquaredActivation, - "relu6": nn.ReLU6, - "sigmoid": nn.Sigmoid, - "silu": nn.SiLU, - "swish": nn.SiLU, - "tanh": nn.Tanh, -} -ACT2FN = ClassInstantier(ACT2CLS) - - -def get_activation(activation_string): - if activation_string in ACT2FN: - return ACT2FN[activation_string] - else: - raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") - - -# For backwards compatibility with: from activations import gelu_python -gelu_python = get_activation("gelu_python") -gelu_new = get_activation("gelu_new") -gelu = get_activation("gelu") -gelu_fast = get_activation("gelu_fast") -quick_gelu = get_activation("quick_gelu") -silu = get_activation("silu") -mish = get_activation("mish") -linear_act = get_activation("linear") diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py deleted file mode 100644 index 83d7251dda..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py +++ /dev/null @@ -1,1447 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# mypy: ignore-errors -import copy -import importlib.metadata -import json -import os -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from packaging import version - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import is_torchdynamo_compiling, logging - - -logger = logging.get_logger(__name__) - - -class Cache(torch.nn.Module): - """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. - """ - - def __init__(self): - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states, if there is any.""" - raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_length() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) - - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None - - -@dataclass -class CacheConfig: - """ - Base class for cache configs - """ - - cache_implementation: None - - @classmethod - def from_dict(cls, config_dict, **kwargs): - """ - Constructs a CacheConfig instance from a dictionary of parameters. - Args: - config_dict (Dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. - """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - return copy.deepcopy(self.__dict__) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. - Returns: - str: JSON formatted string representing the configuration instance. - """ - return json.dumps(self.__dict__, indent=2) + "\n" - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update - def update(self, **kwargs): - """ - Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, - returning all the unused kwargs. - - Args: - kwargs (`Dict[str, Any]`): - Dictionary of attributes to tentatively update this class. - - Returns: - `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. - """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) - - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs - - -class DynamicCache(Cache): - """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = DynamicCache() - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - def __init__(self) -> None: - super().__init__() - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - self._seen_tokens = ( - 0 # Used in `generate` to keep tally of how many tokens the cache has seen - ) - - def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" - return None - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def crop(self, max_length: int): - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - # In case it is negative - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - self._seen_tokens = max_length - for idx in range(len(self.key_cache)): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - out = [] - for i in range(0, full_batch_size, split_size): - current_split = DynamicCache() - current_split._seen_tokens = self._seen_tokens - current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] - current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] - out.append(current_split) - return out - - @classmethod - def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - cache = cls() - for idx in range(len(splits[0])): - layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) - cache.update(layer_keys, layer_values, idx) - return cache - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( - repeats, dim=0 - ) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] - - -class OffloadedCache(DynamicCache): - """ - A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. - Useful for generating from models with very long context. - - In addition to the default CUDA stream, where all forward() computations happen, - this class uses another stream, the prefetch stream, which it creates itself. - Since scheduling of operations on separate streams happens independently, this class uses - the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. - The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to - ensure the eviction is scheduled after all computations on that cache are finished. - """ - - def __init__(self) -> None: - if not torch.cuda.is_available(): - raise RuntimeError("OffloadedCache can only be used with a GPU") - super().__init__() - self.original_device = [] - self.prefetch_stream = torch.cuda.Stream() - self.beam_idx = None # used to delay beam search operations - - def prefetch_layer(self, layer_idx: int): - "Starts prefetching the next layer cache" - if layer_idx < len(self): - with torch.cuda.stream(self.prefetch_stream): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to( - device, non_blocking=True - ) - - def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - if len(self) > 2: - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(self) - self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to( - "cpu", non_blocking=True - ) - self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to( - "cpu", non_blocking=True - ) - - def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: - "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." - if layer_idx < len(self): - # Evict the previous layer if necessary - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) - # Load current layer cache to its original device if not already there - original_device = self.original_device[layer_idx] - self.prefetch_stream.synchronize() - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - # Now deal with beam search ops which were delayed - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(original_device) - key_tensor = key_tensor.index_select(0, self.beam_idx) - value_tensor = value_tensor.index_select(0, self.beam_idx) - # Prefetch the next layer - self.prefetch_layer((layer_idx + 1) % len(self)) - return (key_tensor, value_tensor) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Saves the beam indices and reorders the cache when the tensor is back to its device.""" - # We delay this operation until the tensors are back to their original - # device because performing torch.index_select on the CPU is very slow - del self.beam_idx - self.beam_idx = beam_idx.clone() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.original_device.append(key_states.device) - self.evict_previous_layer(layer_idx) - else: - key_tensor, value_tensor = self[layer_idx] - self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError - # if a method is not supposed to be supported in a subclass we should set it to None - from_legacy_cache = None - - to_legacy_cache = None - - -class SinkCache(Cache): - """ - A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to - generate beyond the length of its context window, without losing fluency in the conversation. As it discards past - tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Parameters: - window_length (`int`): - The length of the context window. - num_sink_tokens (`int`): - The number of sink tokens. See the original paper for more information. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - def __init__(self, window_length: int, num_sink_tokens: int) -> None: - super().__init__() - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - self.window_length = window_length - self.num_sink_tokens = num_sink_tokens - self.cos_sin_rerotation_cache = {} - self._cos_cache = None - self._sin_cache = None - self._seen_tokens = ( - 0 # Used in `generate` to keep tally of how many tokens the cache has seen - ) - - @staticmethod - def _rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_key_rotary_pos_emb( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> torch.Tensor: - rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) - return rotated_key_states - - def _get_rerotation_cos_sin( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - if key_states.shape[-2] not in self.cos_sin_rerotation_cache: - # Upcast to float32 temporarily for better accuracy - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence - original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] - shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] - original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] - shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - - self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( - rerotation_cos.to(key_states.dtype).unsqueeze(0), - rerotation_sin.to(key_states.dtype).unsqueeze(0), - ) - return self.cos_sin_rerotation_cache[key_states.shape[-2]] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" - return self.window_length - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, - `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the - rotation as the tokens are shifted. - - Return: - A tuple containing the updated key and value states. - """ - # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models - # with partially rotated position embeddings, like Phi or Persimmon. - sin = cache_kwargs.get("sin") - cos = cache_kwargs.get("cos") - partial_rotation_size = cache_kwargs.get("partial_rotation_size") - using_rope = cos is not None and sin is not None - - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the sin/cos cache, which holds sin/cos values for all possible positions - if using_rope and layer_idx == 0: - # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove - # after all RoPE models have a llama-like cache utilization. - if cos.dim() == 2: - self._cos_cache = cos - self._sin_cache = sin - else: - if self._cos_cache is None: - self._cos_cache = cos[0, ...] - self._sin_cache = sin[0, ...] - elif self._cos_cache.shape[0] < self.window_length: - self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) - self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) - - # [bsz, num_heads, seq_len, head_dim] - if len(self.key_cache) <= layer_idx: - # Empty cache - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: - # Growing cache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - else: - # Shifting cache - keys_to_keep = self.key_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : - ] - - # On RoPE models, we need to recompute the Key rotation as the tokens are shifted - if using_rope: - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( - key_states, - self._cos_cache[: self.window_length], - self._sin_cache[: self.window_length], - ) - if partial_rotation_size is not None: - keys_to_keep, keys_pass = ( - keys_to_keep[..., :partial_rotation_size], - keys_to_keep[..., partial_rotation_size:], - ) - keys_to_keep = self._apply_key_rotary_pos_emb( - keys_to_keep, rerotation_cos, rerotation_sin - ) - if partial_rotation_size is not None: - keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) - - # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] - self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) - - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] - values_to_keep = self.value_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : - ] - self.value_cache[layer_idx] = torch.cat( - [sink_values, values_to_keep, value_states], dim=-2 - ) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - -class StaticCache(Cache): - """ - Static Cache class to be used with `torch.compile(model)` and `torch.export()`. - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`torch.device`): - The device on which the cache should be initialized. Should be the same as the layer. - dtype (*optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - def __init__( - self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None - ) -> None: - super().__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = ( - config.max_position_embeddings if max_cache_len is None else max_cache_len - ) - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - self.dtype = dtype if dtype is not None else torch.float32 - self.num_key_value_heads = ( - config.num_attention_heads - if config.num_key_value_heads is None - else config.num_key_value_heads - ) - - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - # Note: There will be significant perf decrease if switching to use 5D tensors instead. - cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - for idx in range(config.num_hidden_layers): - new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - # Notes: - # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case - # it is not needed anyway) - # 2. `torch.export()` requires mutations to be registered as buffers. - if not is_torchdynamo_compiling(): - self.register_buffer( - f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device) - ) - self.register_buffer( - f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device) - ) - new_layer_key_cache = getattr(self, f"key_cache_{idx}") - new_layer_value_cache = getattr(self, f"value_cache_{idx}") - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - self._seen_tokens = ( - 0 # Used in `generate` to keep tally of how many tokens the cache has seen - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - cache_position = cache_kwargs.get("cache_position") - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - # return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - return self._seen_tokens - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" - return self.max_cache_len - - def reset(self): - self._seen_tokens = 0 - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class SlidingWindowCache(StaticCache): - """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - - indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) - - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`torch.device`): - The device on which the cache should be initialized. Should be the same as the layer. - dtype (*optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - def __init__( - self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None - ) -> None: - super().__init__(config, max_batch_size, max_cache_len, device, dtype) - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - max_cache_len = min(config.sliding_window, max_cache_len) - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor]: - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - - # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) - if cache_position.shape[0] > self.max_cache_len: - k_out = key_states[:, :, -self.max_cache_len :, :] - v_out = value_states[:, :, -self.max_cache_len :, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones( - self.max_cache_len, dtype=torch.long, device=value_states.device - ).cumsum(0) - cache_position = cache_position.clamp(0, self.max_cache_len - 1) - to_shift = cache_position >= self.max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len - - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - try: - cache_position.to(device=k_out.device) - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - - return k_out, v_out - - def get_max_length(self) -> Optional[int]: - # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is - return None - - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class EncoderDecoderCache(Cache): - """ - Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and - cross-attention caches. - - Example: - - ```python - >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") - >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") - - >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") - - >>> # Prepare cache classes for encoder and decoder and pass it to model's forward - >>> self_attention_cache = DynamicCache() - >>> cross_attention_cache = DynamicCache() - >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - - """ - - def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): - super().__init__() - self.self_attention_cache = self_attention_cache - self.cross_attention_cache = cross_attention_cache - - self.is_updated = {} - for layer_idx in range(len(cross_attention_cache.key_cache)): - self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) - - def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], - ) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.self_attention_cache) - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" - legacy_cache = () - if len(self.cross_attention_cache) > 0: - for self_attn, cross_attn in zip( - self.self_attention_cache.to_legacy_cache(), - self.cross_attention_cache.to_legacy_cache(), - ): - legacy_cache += (self_attn + cross_attn,) - else: - legacy_cache = self.self_attention_cache.to_legacy_cache() - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "EncoderDecoderCache": - """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" - cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx][:2] - cache.self_attention_cache.update(key_states, value_states, layer_idx) - if len(past_key_values[layer_idx]) > 2: - key_states, value_states = past_key_values[layer_idx][2:] - cache.cross_attention_cache.update(key_states, value_states, layer_idx) - cache.is_updated[layer_idx] = True - return cache - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - if len(self.self_attention_cache.key_cache) <= layer_idx: - return 0 - return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - if hasattr(self.self_attention_cache, "reset"): - self.self_attention_cache.reset() - if hasattr(self.cross_attention_cache, "reset"): - self.cross_attention_cache.reset() - elif not hasattr(self.self_attention_cache, "reset") and not hasattr( - self.cross_attention_cache, "reset" - ): - raise ValueError( - "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " - "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " - f"Got {self.self_attention_cache.__str__()} for the self attention cache and " - f"{self.cross_attention_cache.__str__()} for the cross attention cache." - ) - for layer_idx in self.is_updated: - self.is_updated[layer_idx] = False - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - self.self_attention_cache.reorder_cache(beam_idx) - self.cross_attention_cache.reorder_cache(beam_idx) - - def check_dynamic_cache(self, method: str): - if not ( - isinstance(self.self_attention_cache, DynamicCache) - and isinstance(self.cross_attention_cache, DynamicCache) - ): - raise ValueError( - f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " - f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." - ) - - # TODO(gante, sanchit-gandhi): move following functionality into `.generate` - def crop(self, maximum_length: int): - """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" - self.check_dynamic_cache(self.crop.__name__) - self.self_attention_cache.crop(maximum_length) - - def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - self.check_dynamic_cache(self.batch_split.__name__) - self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) - cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) - - out = [] - for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): - out.append(EncoderDecoderCache(self_attn, cross_attn)) - return out - - @classmethod - def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - for idx in range(len(splits[0])): - layer_keys = torch.cat( - [current.self_attention_cache.key_cache[idx] for current in splits], dim=0 - ) - layer_values = torch.cat( - [current.self_attention_cache.value_cache[idx] for current in splits], dim=0 - ) - self_attention_cache.update(layer_keys, layer_values, idx) - - layer_keys = torch.cat( - [current.cross_attention_cache.key_cache[idx] for current in splits], dim=0 - ) - layer_values = torch.cat( - [current.cross_attention_cache.value_cache[idx] for current in splits], dim=0 - ) - cross_attention_cache.update(layer_keys, layer_values, idx) - return cls(self_attention_cache, cross_attention_cache) - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_repeat_interleave.__name__) - self.self_attention_cache.batch_repeat_interleave(repeats) - self.cross_attention_cache.batch_repeat_interleave(repeats) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_select_indices.__name__) - self.self_attention_cache.batch_select_indices(indices) - self.cross_attention_cache.batch_select_indices(indices) - - -class HybridCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention - and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention - and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. - - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`torch.device`, *optional*, defaults to `"cpu"`): - The device on which the cache should be initialized. Should be the same as the layer. - dtype (*optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - def __init__( - self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - self.max_cache_len = max_cache_len - self.max_batch_size = max_batch_size - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - self.dtype = dtype if dtype is not None else torch.float32 - self.num_key_value_heads = ( - config.num_attention_heads - if config.num_key_value_heads is None - else config.num_key_value_heads - ) - self.is_sliding = torch.tensor( - [not bool(i % 2) for i in range(config.num_hidden_layers)], - dtype=torch.bool, - device=device, - ) - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - global_cache_shape = ( - max_batch_size, - self.num_key_value_heads, - max_cache_len, - self.head_dim, - ) - sliding_cache_shape = ( - max_batch_size, - self.num_key_value_heads, - min(config.sliding_window, max_cache_len), - self.head_dim, - ) - for i in range(config.num_hidden_layers): - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def _sliding_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - if cache_position.shape[0] > max_cache_len: - k_out = key_states[:, :, -max_cache_len:, :] - v_out = value_states[:, :, -max_cache_len:, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) - cache_position = cache_position.clamp(0, max_cache_len - 1) - to_shift = cache_position >= max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % max_cache_len - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - return k_out, v_out - - def _static_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor]: - cache_position = cache_kwargs.get("cache_position") - sliding_window = cache_kwargs.get("sliding_window") - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - if sliding_window: - update_fn = self._sliding_update - else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) - - def get_max_length(self) -> Optional[int]: - # in theory there is no limit because the sliding window size is fixed - # no matter how long the sentence is - return self.max_cache_len - - def get_seq_length(self, layer_idx: Optional[int] = 0): - return None - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - dtype (*optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Attributes: - dtype: (`torch.dtype`): - The default `dtype` used to initializing the cache. - intermediate_size: (`int`): - Model's intermediate_size taken from config. - ssm_state_size: (`int`): - Model's state_size taken from config. - conv_kernel_size: (`int`): - Model's convolution kernel size taken from config - conv_states: (`torch.Tensor`): - A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states. - ssm_states: (`torch.Tensor`): - A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states - - Example: - - ```python - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv = outputs.past_key_values - ``` - """ - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: Optional[str] = None, - **kwargs, - ): - self.dtype = dtype - self.max_batch_size = max_batch_size - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - - self.conv_states: torch.Tensor = torch.zeros( - config.num_hidden_layers, - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states: torch.Tensor = torch.zeros( - config.num_hidden_layers, - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=dtype, - ) - - torch._dynamo.mark_static_address(self.conv_states) - torch._dynamo.mark_static_address(self.ssm_states) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py deleted file mode 100644 index 461996f742..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py +++ /dev/null @@ -1,219 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from .transformers_4_44_2__modeling_rope_utils import rope_config_validation - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, - Llama 2 up to 4096, CodeLlama up to 16384. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to - understand more about it. This value is necessary to ensure exact reproducibility of the pretraining - results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - mlp_bias=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.mlp_bias = mlp_bias - - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. - if self.rope_scaling is not None and "type" in self.rope_scaling: - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py deleted file mode 100644 index 761c2b6402..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py +++ /dev/null @@ -1,574 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional, Tuple - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import is_torch_available, logging - - -logger = logging.get_logger(__name__) - - -if is_torch_available(): - import torch - - -def _compute_default_rope_parameters( - config: Optional[PretrainedConfig] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - **rope_kwargs, -) -> Tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) - return inv_freq, attention_factor - - -def _compute_linear_scaling_rope_parameters( - config: Optional[PretrainedConfig] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - **rope_kwargs, -) -> Tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - factor = rope_kwargs["factor"] - elif config is not None: - factor = config.rope_scaling["factor"] - - # Gets the default RoPE parameters - inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) - - # Then applies linear scaling to the frequencies. - # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so - # applying scaling to the inverse frequencies is equivalent. - inv_freq /= factor - return inv_freq, attention_factor - - -def _compute_dynamic_ntk_parameters( - config: Optional[PretrainedConfig] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - **rope_kwargs, -) -> Tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length, used to update the dynamic RoPE at inference time. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - max_position_embeddings = rope_kwargs["max_position_embeddings"] - factor = rope_kwargs["factor"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - max_position_embeddings = config.max_position_embeddings - factor = config.rope_scaling["factor"] - - attention_factor = 1.0 # Unused in this type of RoPE - - # seq_len: default to max_position_embeddings, e.g. at init time - seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings - - # Compute the inverse frequencies - base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) - return inv_freq, attention_factor - - -def _compute_yarn_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs -) -> Tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with NTK scaling. Please refer to the - [original paper](https://arxiv.org/abs/2309.00071) - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # No need to keep BC with yarn, unreleased when this new pattern was created. - if len(rope_kwargs) > 0: - raise ValueError( - f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" - ) - - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - max_position_embeddings = config.max_position_embeddings - factor = config.rope_scaling["factor"] - - # Sets the attention factor as suggested in the paper - attention_factor = config.rope_scaling.get("attention_factor") - if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 - - # Optional config options - # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = config.rope_scaling.get("beta_fast") or 32 - beta_slow = config.rope_scaling.get("beta_slow") or 1 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): - """Find dimension range bounds based on rotations""" - low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # to expand the possible context length. In other words, interpolation = apply scaling factor. - pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) - - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor - ) - - return inv_freq, attention_factor - - -def _compute_longrope_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs -) -> Tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with LongRoPE scaling. Please refer to the - [original implementation](https://github.com/microsoft/LongRoPE) - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - # No need to keep BC with longrope, unreleased when this new pattern was created. - if len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " - f"{rope_kwargs}" - ) - - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - long_factor = config.rope_scaling["long_factor"] - short_factor = config.rope_scaling["short_factor"] - factor = config.rope_scaling.get("factor") - attention_factor = config.rope_scaling.get("attention_factor") - - # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a - # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two - # values to compute the default attention scaling factor, instead of using `factor`. - if hasattr(config, "original_max_position_embeddings"): - max_position_embeddings = config.original_max_position_embeddings - expanded_max_position_embeddings = config.max_position_embeddings - factor = expanded_max_position_embeddings / max_position_embeddings - else: - max_position_embeddings = config.max_position_embeddings - expanded_max_position_embeddings = max_position_embeddings * factor - - # Sets the attention factor as suggested in the paper - if attention_factor is None: - if factor <= 1.0: - attention_factor = 1.0 - else: - attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) - - # Compute the inverse frequencies -- scaled based on the target sequence length - if expanded_max_position_embeddings > max_position_embeddings: - ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) - else: - ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) - inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim - inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) - - return inv_freq, attention_factor - - -def _compute_llama3_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs -) -> Tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies for llama 3.1. - - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # Gets the default RoPE parameters - inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) - - factor = config.rope_scaling["factor"] # `8` in the original implementation - low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation - high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation - old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - wavelen = 2 * math.pi / inv_freq - # wavelen < high_freq_wavelen: do nothing - # wavelen > low_freq_wavelen: divide by factor - inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) - # otherwise: interpolate between the two, using a smooth factor - smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) - smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama - is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) - inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - - return inv_freq_llama, attention_factor - - -# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters -# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE -# parameterizations, as long as the callable has the same signature. -ROPE_INIT_FUNCTIONS = { - "default": _compute_default_rope_parameters, - "linear": _compute_linear_scaling_rope_parameters, - "dynamic": _compute_dynamic_ntk_parameters, - "yarn": _compute_yarn_parameters, - "longrope": _compute_longrope_parameters, - "llama3": _compute_llama3_parameters, -} - - -def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): - """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" - # BC: "rope_type" was originally "type" -- let's gracefully handle it - if "rope_type" not in received_keys and "type" in received_keys: - received_keys -= {"type"} - received_keys.add("rope_type") - - missing_keys = required_keys - received_keys - if missing_keys: - raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") - - if optional_keys is not None: - unused_keys = received_keys - required_keys - optional_keys - else: - unused_keys = received_keys - required_keys - if unused_keys: - logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") - - -def _validate_default_rope_parameters(config: PretrainedConfig): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys) - - -def _validate_linear_scaling_rope_parameters(config: PretrainedConfig): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - -def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor"} - # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` - optional_keys = {"original_max_position_embeddings"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - -def _validate_yarn_parameters(config: PretrainedConfig): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor"} - optional_keys = {"attention_factor", "beta_fast", "beta_slow"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_scaling.get("attention_factor") - if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): - logger.warning( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - beta_fast = rope_scaling.get("beta_fast") - if beta_fast is not None and not isinstance(beta_fast, float): - logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - beta_slow = rope_scaling.get("beta_slow") - if beta_slow is not None and not isinstance(beta_slow, float): - logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - - if (beta_fast or 32) < (beta_slow or 1): - logger.warning( - f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " - f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" - ) - - -def _validate_longrope_parameters(config: PretrainedConfig): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "short_factor", "long_factor"} - # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` - optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys) - - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - short_factor = rope_scaling.get("short_factor") - if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): - logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") - if not len(short_factor) == dim // 2: - logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") - - long_factor = rope_scaling.get("long_factor") - if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): - logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") - if not len(long_factor) == dim // 2: - logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") - - # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over - # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is - # unique to longrope (= undesirable) - if hasattr(config, "original_max_position_embeddings"): - logger.warning_once( - "This model has set a `original_max_position_embeddings` field, to be used together with " - "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" - "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " - "as it is compatible with most model architectures." - ) - else: - factor = rope_scaling.get("factor") - if factor is None: - logger.warning("Missing required keys in `rope_scaling`: 'factor'") - elif not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_scaling.get("attention_factor") - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - logger.warning( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - - -def _validate_llama3_parameters(config: PretrainedConfig): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - low_freq_factor = rope_scaling["low_freq_factor"] - high_freq_factor = rope_scaling["high_freq_factor"] - if low_freq_factor is None or not isinstance(low_freq_factor, float): - logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") - if high_freq_factor is None or not isinstance(high_freq_factor, float): - logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") - if high_freq_factor <= low_freq_factor: - logger.warning( - "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" - f"{high_freq_factor} and low_freq_factor={low_freq_factor}" - ) - - original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] - if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): - logger.warning( - "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " - f"{original_max_position_embeddings}" - ) - if original_max_position_embeddings >= config.max_position_embeddings: - logger.warning( - "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " - f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" - ) - - -# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. -ROPE_VALIDATION_FUNCTIONS = { - "default": _validate_default_rope_parameters, - "linear": _validate_linear_scaling_rope_parameters, - "dynamic": _validate_dynamic_scaling_rope_parameters, - "yarn": _validate_yarn_parameters, - "longrope": _validate_longrope_parameters, - "llama3": _validate_llama3_parameters, -} - - -def rope_config_validation(config: PretrainedConfig): - """ - Validate the RoPE config arguments, given a `PretrainedConfig` object - """ - rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` - if rope_scaling is None: - return - - # BC: "rope_type" was originally "type" - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) - validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) - if validation_fn is not None: - validation_fn(config) - else: - logger.warning( - f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" - ) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py deleted file mode 100644 index 7dc65a0923..0000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py +++ /dev/null @@ -1,447 +0,0 @@ -# coding=utf-8 -# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - - -class Llama4VisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a - Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Llama4 109B. - - e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - hidden_size (`int`, *optional*, defaults to 768): - Dimensionality of the encoder layers and the pooler layer. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. - num_hidden_layers (`int`, *optional*, defaults to 34): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer encoder. - num_channels (`int`, *optional*, defaults to 3): - Number of channels in the input image. - intermediate_size (`int`, *optional*, defaults to 5632): - Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. - vision_output_dim (`int`, *optional*, defaults to 7680): - Dimensionality of the vision model output. Includes output of transformer - encoder with intermediate layers and global transformer encoder. - image_size (`int`, *optional*, defaults to 448): - The size (resolution) of each image *tile*. - patch_size (`int`, *optional*, defaults to 14): - The size (resolution) of each patch. - norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the layer normalization layers. - vision_feature_layer (``, *optional*, defaults to -1): TODO - vision_feature_select_strategy (`int`, *optional*, defaults to `"default"`): TODO - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - pixel_shuffle_ratio (`int`, *optional*, defaults to 0.5): TODO - projector_input_dim (`int`, *optional*, defaults to 4096): TODO - projector_output_dim (`int`, *optional*, defaults to 4096): TODO - multi_modal_projector_bias (`int`, *optional*, defaults to `False`): TODO - projector_dropout (`int`, *optional*, defaults to 0.0): TODO - attention_dropout (`int`, *optional*, defaults to 0.0): TODO - rope_theta (`int`, *optional*, defaults to 10000): TODO - """ - - base_model_tp_plan = { - "model.layers.*.self_attn.q_proj": "colwise", - "model.layers.*.self_attn.k_proj": "colwise", - "model.layers.*.self_attn.v_proj": "colwise", - "model.layers.*.self_attn.o_proj": "rowwise", - "vision_adapter.mlp.fc1": "colwise", - "vision_adapter.mlp.fc2": "rowwise", - "patch_embedding.linear": "colwise_rep", - } - model_type = "llama4_vision_model" - base_config_key = "vision_config" - - def __init__( - self, - hidden_size: int = 768, - hidden_act: str = "gelu", - num_hidden_layers: int = 34, - num_attention_heads: int = 16, - num_channels: int = 3, - intermediate_size: int = 5632, - vision_output_dim: int = 7680, - image_size: int = 448, - patch_size: int = 14, - norm_eps: float = 1e-5, - vision_feature_layer=-1, - vision_feature_select_strategy="default", - initializer_range: float = 0.02, - pixel_shuffle_ratio=0.5, - projector_input_dim=4096, - projector_output_dim=4096, - multi_modal_projector_bias=False, - projector_dropout=0.0, - attention_dropout=0.0, - rope_theta=10000, - **kwargs, - ): - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.num_hidden_layers = num_hidden_layers - self.num_channels = num_channels - self.intermediate_size = intermediate_size - self.image_size = image_size - self.vision_output_dim = vision_output_dim - self.patch_size = patch_size - self.norm_eps = norm_eps - self.num_attention_heads = num_attention_heads - self.initializer_range = initializer_range - self.pixel_shuffle_ratio = pixel_shuffle_ratio - self.projector_input_dim = projector_input_dim - self.projector_output_dim = projector_output_dim - self.multi_modal_projector_bias = multi_modal_projector_bias - self.projector_dropout = projector_dropout - self.attention_dropout = attention_dropout - self.vision_feature_layer = vision_feature_layer - self.vision_feature_select_strategy = vision_feature_select_strategy - self.rope_theta = rope_theta - super().__init__(**kwargs) - - -class Llama4TextConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a - Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Llama4 109B. - - e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - vocab_size (`int`, *optional*, defaults to 202048): - Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented - by the `inputs_ids` passed when calling [`Llama4TextModel`]. - hidden_size (`int`, *optional*, defaults to 5120): - Dimensionality of the embeddings and hidden states. - intermediate_size (`int`, *optional*, defaults to 8192): - Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. - intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO - num_hidden_layers (`int`, *optional*, defaults to 48): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 40): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If not - specified, will default to `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 128): TODO - hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the encoder and pooler. - max_position_embeddings (`int`, *optional*, defaults to 131072): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions. - pad_token_id (`int`, *optional*, defaults to 128004): - The id of the padding token. - bos_token_id (`int`, *optional*, defaults to 1): - The id of the beginning of sentence token. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the end of sentence token. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to `500000.0`): - The base period of the RoPE embeddings. - attention_dropout (`int`, *optional*, defaults to 0.0): TODO - num_experts_per_tok (`int`, *optional*, defaults to 1): TODO - num_local_experts (`int`, *optional*, defaults to 16): TODO - moe_layers (`int`, *optional*): TODO - interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO - use_qk_norm (`int`, *optional*, defaults to `True`): TODO - output_router_logits (`int`, *optional*, defaults to `False`): TODO - router_aux_loss_coef (`int`, *optional*, defaults to 0.001): TODO - router_jitter_noise (`int`, *optional*, defaults to 0.0): TODO - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - - - no_rope_layers (`int`, *optional*): TODO - no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO - attention_chunk_size (`int`, *optional*, defaults to 8192): - - attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO - floor_scale (`int`, *optional*, defaults to 8192): TODO - attn_scale (`int`, *optional*, defaults to 0.1): TODO - cache_implementation (``, *optional*, defaults to `"hybrid"`): - - Example: - """ - - model_type = "llama4_text" - keys_to_ignore_at_inference = ["past_key_values"] - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.input_layernorm.weight": "sequence_parallel", - "layers.*.post_attention_layernorm.weight": "sequence_parallel", - "norm.weight": "sequence_parallel", - "layers.*.feed_forward.shared_expert.gate_proj": "local_colwise", - "layers.*.feed_forward.shared_expert.up_proj": "local_colwise", - "layers.*.feed_forward.shared_expert.down_proj": "local_rowwise", - "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise", # row because not linear - "layers.*.feed_forward.experts.down_proj": "local_colwise", # col because not linear - "layers.*.feed_forward.experts": "local", - "layers.*.feed_forward.gate_proj": "local_colwise", - "layers.*.feed_forward.up_proj": "local_colwise", - "layers.*.feed_forward.down_proj": "local_rowwise", - "layers.*.feed_forward": "gather", - } - - def __init__( - self, - vocab_size=202048, - hidden_size=5120, - intermediate_size=8192, - intermediate_size_mlp=16384, - num_hidden_layers=48, - num_attention_heads=40, - num_key_value_heads=8, - head_dim=128, - hidden_act="silu", - max_position_embeddings=4096 * 32, - initializer_range=0.02, - rms_norm_eps=1e-5, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - rope_theta=500000, - attention_dropout=0.0, - num_experts_per_tok=1, - num_local_experts=16, - moe_layers=None, - interleave_moe_layer_step=1, - use_qk_norm=True, - output_router_logits=False, - router_aux_loss_coef=0.001, - router_jitter_noise=0.0, - rope_scaling=None, - no_rope_layers=None, - no_rope_layer_interval=4, - attention_chunk_size=8192, - attn_temperature_tuning=4, - floor_scale=8192, - attn_scale=0.1, - cache_implementation="hybrid", - **kwargs, - ): - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - self.attn_temperature_tuning = attn_temperature_tuning - self.attn_scale = attn_scale - self.floor_scale = floor_scale - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.intermediate_size_mlp = intermediate_size_mlp - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.rope_scaling = rope_scaling - self.attention_bias = False - self.cache_implementation = cache_implementation - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - self.use_qk_norm = use_qk_norm - - self.num_experts_per_tok = num_experts_per_tok - self.num_local_experts = num_local_experts - - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - self.router_jitter_noise = router_jitter_noise - default_no_rope_layers = [ - int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers) - ] - - # no_rope_layers == [] is invalid as we cannot have 0 layers - self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers - - self.interleave_moe_layer_step = interleave_moe_layer_step - self.moe_layers = ( - moe_layers - if moe_layers is not None - else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step)) - ) - self.attention_chunk_size = attention_chunk_size - - -class Llama4Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Llama4Model`]. It is used to instantiate an - Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Llama4 109B. - - e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vision_config (`Llama4VisionConfig`, *optional*): - The Llama4 Vision config. - text_config (`Llama4TextConfig`, *optional*): - The Llama4 Text config. - boi_token_index (`int`, *optional*, defaults to 200080): - The begin-of-image token index to wrap the image prompt. - eoi_token_index (`int`, *optional*, defaults to 200081): - The end-of-image token index to wrap the image prompt. - image_token_index (`int`, *optional*, defaults to 200092): - The image token index to encode the image prompt. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - - ```python - >>> from transformers import Llama4Model, Llama4Config - - >>> # Initializing a Llama4 7B style configuration - >>> configuration = Llama4Config() - - >>> # Initializing a model from the Llama4 7B style configuration - >>> model = Llama4Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "llama4" - sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig} - base_model_tp_plan = { - "multi_modal_projector.linear_1": "colwise_rep", - } - - def __init__( - self, - vision_config=None, - text_config=None, - boi_token_index=200080, - eoi_token_index=200081, - image_token_index=200092, - tie_word_embeddings=False, - **kwargs, - ): - if vision_config is None: - self.vision_config = Llama4VisionConfig() - logger.info("vision_config is None, using default llama4 vision config") - elif isinstance(vision_config, dict): - self.vision_config = Llama4VisionConfig(**vision_config) - elif isinstance(vision_config, Llama4VisionConfig): - self.vision_config = vision_config - - self.boi_token_index = boi_token_index - self.eoi_token_index = eoi_token_index - self.image_token_index = image_token_index - if text_config is None: - self.text_config = Llama4TextConfig() - logger.info("text_config is None, using default llama4 text config") - elif isinstance(text_config, dict): - self.text_config = Llama4TextConfig(**text_config) - elif isinstance(text_config, Llama4TextConfig): - self.text_config = text_config - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - - -__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"] diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index cc81f4f887..a7ed3f7d37 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -422,7 +422,7 @@ def _get_last_checkpoint_from_each_experiment( "If you are Ido Galil, tell Tomer that you got this exception ;) " ) - # Filter out non-DeciLM checkpoints (e.g., unconverted Llama checkpoints) + # Filter out checkpoints without block_configs (e.g. unconverted raw HF layouts) valid_checkpoint_dirs = [ cp for cp in checkpoint_dirs diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index f0d5bb0583..c1eb0b9b48 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -28,7 +28,6 @@ from transformers import PretrainedConfig, PreTrainedModel from modelopt.torch.puzzletron.anymodel.converter.converter import Converter -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( extract_block_configs_and_locations, parse_layer_replacement, @@ -76,7 +75,7 @@ def _ensure_all_checkpoints_are_split(self) -> None: assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}" @property - def model_config(self) -> DeciLMConfig: + def model_config(self) -> PretrainedConfig: if self._model_config is None: trust_remote_code = self.descriptor.requires_trust_remote_code() self._model_config = load_model_config( diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index 1dcc6c1b22..cb178e0566 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -43,7 +43,6 @@ FFNConfig, SubblockConfig, ) -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement from modelopt.torch.puzzletron.subblock_stats.calc_subblock_params_and_memory import ( calc_subblock_active_params, @@ -300,7 +299,7 @@ def calculate_subblock_stats_for_puzzle_dir( model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code) # Get language model config for LM-specific attributes (VL models have nested config) lm_config = descriptor.get_language_model_config(model_config) - subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes, model_config) + subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes) subblock_stats_file = master_puzzle_dir / subblock_stats_filename if subblock_stats_file.exists() and not merge_with_existing_stats: @@ -383,7 +382,7 @@ def calculate_subblock_stats_for_puzzle_dir( def _load_subblock_configs( - master_puzzle_dir: Path, ffn_hidden_sizes: ListConfig, model_config: DeciLMConfig + master_puzzle_dir: Path, ffn_hidden_sizes: ListConfig ) -> list[SubblockConfig]: try: subblock_configs = _load_subblock_configs_from_replacement_library(master_puzzle_dir) diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index b30e7eefa9..6b98d36a0e 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -29,6 +29,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from transformers import PretrainedConfig from typeguard import check_type from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( @@ -37,7 +38,6 @@ _get_dataclass_type, _is_dataclass_type, ) -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.pruning.pruning_utils import ( ACTIVATIONS_LOG, GQAInitMode, @@ -72,8 +72,8 @@ def _process_single_layer( descriptor, parent_state_dict: dict, new_state_dict: dict, - original_config: DeciLMConfig, - new_config: DeciLMConfig, + original_config: PretrainedConfig, + new_config: PretrainedConfig, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, mlp_init_config: Optional[dict[str, Any]], @@ -336,8 +336,8 @@ def create_child_state_dict( descriptor, original_state_dict: dict, new_state_dict: dict, - original_config: DeciLMConfig, - new_config: DeciLMConfig, + original_config: PretrainedConfig, + new_config: PretrainedConfig, gqa_init_mode: GQAInitMode, ignore_fn: IgnoreFn = default_ignore_fn, mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, @@ -667,11 +667,11 @@ def _init_mlp( *, mlp_init_mode: Union[MlpInitMode, str], layer_idx: int, - original_config: DeciLMConfig, + original_config: PretrainedConfig, mlp_init_config: Optional[dict[str, Any]], original_state_dict: dict, new_state_dict: dict, - new_config: DeciLMConfig, + new_config: PretrainedConfig, keys: dict[str, str], ignored_keys: set[str], expert_idx: Optional[int] = None, @@ -749,7 +749,7 @@ def _prune_experts_by_score( def _init_linear_attn( parent_state_dict: dict[str, torch.Tensor], - parent_config: DeciLMConfig, + parent_config: PretrainedConfig, layer_idx: int, v_key: str, o_key: str, @@ -788,9 +788,9 @@ def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.T def update_model_config( - model_config: DeciLMConfig, + model_config: PretrainedConfig, model_config_overrides: None | list[dict[str, Any]] | str | dict | Path = None, -) -> DeciLMConfig: +) -> PretrainedConfig: new_model_config = deepcopy(model_config) if model_config_overrides is None: return new_model_config @@ -893,8 +893,8 @@ def _parse_model_config_overrides( def _apply_hidden_size_pruning( out_state_dict: dict[str, torch.Tensor], original_state_dict: dict[str, torch.Tensor], - new_config: DeciLMConfig, - original_config: DeciLMConfig, + new_config: PretrainedConfig, + original_config: PretrainedConfig, descriptor, hidden_size_init_mode: HiddenSizeInitMode, channel_importance_path: Optional[str] = None, diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py index 20c2fbe2ac..0ef4bfa472 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -14,9 +14,7 @@ # limitations under the License. # mypy: ignore-errors -"""It provides general utilities for loading and initializing PyTorch model checkpoints, -particularly for DeciLM models. -""" +"""Utilities for loading and initializing PyTorch model checkpoints (AnyModel / HF layouts).""" import concurrent.futures import warnings @@ -136,7 +134,7 @@ def skip_init(module_cls, *args, **kwargs) -> nn.Module: def is_valid_decilm_checkpoint(checkpoint_dir: Path | str, trust_remote_code: bool = False) -> bool: - """Validate that a checkpoint is in DeciLM format (has block_configs). + """True if the checkpoint config loads and defines ``block_configs`` (AnyModel / puzzletron layout). Args: checkpoint_dir: Path to checkpoint directory @@ -145,13 +143,13 @@ def is_valid_decilm_checkpoint(checkpoint_dir: Path | str, trust_remote_code: bo trust the source of the model. Defaults to False for security. Returns: - True if checkpoint is valid DeciLM format, False otherwise + True if the config has ``block_configs``, False otherwise """ try: model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) if model_config.block_configs is None: warnings.warn( - f"Skipping checkpoint '{checkpoint_dir}' - not in DeciLM format (missing block_configs)" + f"Skipping checkpoint '{checkpoint_dir}' - missing block_configs (not an AnyModel-style layout)" ) return False return True diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 54e2bdafd5..d2bbc4aca1 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -15,11 +15,11 @@ # mypy: ignore-errors """ -Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, -particularly for DeciLM models. +Utilities for loading and saving Hugging Face-format checkpoints (``AutoConfig`` + optional ``block_configs``). """ import concurrent.futures +import contextlib import dataclasses import fcntl import os From e508b76ae2de9a4a60978a521c63e55f98a47410 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 24 Mar 2026 16:01:05 +0100 Subject: [PATCH 51/62] code clean up (#1110) ### What does this PR do? Remove hardcoded trust_remote_code=true ## Summary by CodeRabbit ## Release Notes * **Documentation** * Clarified documentation for parameter counting in model layer calculations. * **Bug Fixes** * Remote code execution is now disabled by default in model configuration loading. Explicitly enable `trust_remote_code` when using custom modeling code. Signed-off-by: Daniel Korzekwa --- .../calc_subblock_params_and_memory.py | 23 +++++++------------ .../puzzletron/tools/checkpoint_utils_hf.py | 8 +++++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py index a93e40978f..3ea57bd7a7 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py @@ -105,17 +105,7 @@ def calculate_subblock_params( layer_config: BlockConfig | FFNConfig | AttentionConfig, descriptor: Type[ModelDescriptor], ) -> int: - """Count parameters on one meta decoder layer (puzzletron ``calculate_subblock_params`` parity). - - Unlike ``puzzletron_patcher`` during ``__init__``, we do **not** use ``deci_x_patcher`` here: - for models such as GPT-OSS, Transformers ``post_init`` validates ``_keep_in_fp32_modules`` - against the module tree; replacing norms / attn / mlp with no-op placeholders **before** - ``post_init`` raises (e.g. ``post_attention_layernorm`` … not part of the modules). - - With ``num_hidden_layers == 1`` we merge ``block_config_to_layer_overrides`` into the LM config - (what the patcher would pass into ``DecoderLayer.__init__``), build a stock layer, run - ``post_init``, then apply ``attn_no_op_post_init`` / ``mlp_no_op_post_init`` for param counting. - """ + """Count parameters on one meta decoder layer.""" if isinstance(layer_config, FFNConfig): block_config = layer_config.to_blockconfig() elif isinstance(layer_config, AttentionConfig): @@ -147,12 +137,15 @@ def calculate_subblock_params( # Replaced earlier pattern: # with EmptyInitOnDevice("meta"), deci_x_patcher(..., block_configs=block_configs): # model = init_model_from_config(_config, ...) + # # That fails on GPT-OSS with recent Transformers: ``deci_x_patcher`` runs # ``attn_no_op_post_init`` / ``mlp_no_op_post_init`` inside ``DecoderLayer.__init__``, so norms - # / attn / mlp are swapped for placeholders before ``GptOssModel.post_init`` runs; ``post_init`` - # then raises ``ValueError`` (e.g. ``post_attention_layernorm`` in ``_keep_in_fp32_modules`` no - # longer matches the tree). Below we merge per-layer fields manually, init without the patcher, - # then call the same descriptor no-op hooks on the built layer (equivalent param count for + # / attn / mlp are swapped for placeholders before ``GptOssModel.__init__`` finishes. At the end + # of ``GptOssModel.__init__`` the stack calls ``self.post_init()`` — inherited from + # ``PreTrainedModel`` — which then raises + # ``ValueError`` (e.g. ``post_attention_layernorm`` in ``_keep_in_fp32_modules`` no longer matches + # the tree). Below we merge per-layer fields manually, init without the patcher, then call the + # same descriptor no-op hooks on the built layer (equivalent param count for # ``num_hidden_layers == 1``). # ``block_config_to_layer_overrides`` may include keys with value ``None``; we omit those so diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index d2bbc4aca1..38dfaaf00b 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -138,10 +138,14 @@ def _get_model_class_from_config(config: PretrainedConfig) -> type: def init_model_from_config( config: PretrainedConfig, *, - trust_remote_code: bool = True, + trust_remote_code: bool = False, **kwargs, ) -> PreTrainedModel: - """Build a model from config on meta/uninitialized weights (used e.g. for subblock param counts).""" + """Build a model from config on meta/uninitialized weights (used e.g. for subblock param counts). + + ``trust_remote_code`` defaults to False (only ``AutoModelForCausalLM.from_config`` uses it). + Pass True when loading configs that rely on custom modeling code from the checkpoint. + """ model_class = _get_model_class_from_config(config) if model_class is AutoModelForCausalLM: return model_class.from_config(config, trust_remote_code=trust_remote_code, **kwargs) From 2f55c7303b56b75599ece8f383e9531a41247c4e Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 25 Mar 2026 17:15:14 +0100 Subject: [PATCH 52/62] Dkorzekwa/puzzletron use importance hooks from prune (#1115) ### What does this PR do? Use pruning importance hooks from prune/importance_hooks/ in modelopt.torch.puzzletron --------- Signed-off-by: Daniel Korzekwa --- .../pruning/ffn_pruning.yaml | 2 +- .../pruning/ffn_pruning.yaml | 2 +- .../pruning/ffn_pruning.yaml | 2 +- .../pruning/ffn_pruning.yaml | 2 +- .../pruning/ffn_pruning.yaml | 2 +- .../pruning/ffn_pruning.yaml | 2 +- .../pruning/ffn_pruning.yaml | 2 +- .../nas/plugins/megatron_hooks/__init__.py | 23 - .../nas/plugins/megatron_hooks/base_hooks.py | 1180 ----------------- .../megatron_hooks/base_hooks_analysis.py | 104 -- .../megatron_hooks/compare_module_outputs.py | 291 ---- .../plugins/megatron_hooks/megatron_hooks.py | 36 - .../torch/prune/importance_hooks/__init__.py | 1 + .../prune/importance_hooks/base_hooks.py | 7 +- .../importance_hooks/expert_removal_hooks.py | 387 ++++++ .../activation_hooks/utils.py | 2 +- modelopt/torch/puzzletron/anymodel/README.md | 10 +- .../pruning/expert_removal_pruning_mixin.py | 4 +- .../pruning/ffn_intermediate_pruning_mixin.py | 2 +- .../pruning/kv_heads_pruning_mixin.py | 2 +- .../torch/puzzletron/pruning/pruning_mixin.py | 2 +- .../puzzletron/utils/checkpoint_manager.py | 4 +- modelopt/torch/utils/robust_json.py | 5 + .../plugins/megatron_hooks/test_base_hooks.py | 100 -- .../test_base_hooks_analysis.py | 173 --- .../pruning/expert_pruning.yaml | 2 +- .../pruning/expert_pruning.yaml | 2 +- .../pruning/ffn_pruning.yaml | 2 +- .../gpt-oss-20b/pruning/expert_removal.yaml | 2 +- .../configs/pruning/attn_pruning.yaml | 2 +- .../configs/pruning/ffn_pruning_base.yaml | 2 +- .../test_mcore_gpt_minitron_pruning.py | 47 - 32 files changed, 423 insertions(+), 1983 deletions(-) delete mode 100644 modelopt/torch/nas/plugins/megatron_hooks/__init__.py delete mode 100644 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py delete mode 100644 modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py delete mode 100644 modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py delete mode 100644 modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py create mode 100644 modelopt/torch/prune/importance_hooks/expert_removal_hooks.py delete mode 100644 tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py delete mode 100644 tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml index 8b19e167d0..258e6c38a3 100644 --- a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml @@ -10,7 +10,7 @@ pruning_mixin: _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor target_name: "mlp.router" -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.RankedChoiceVotingHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook} activation_hooks_kwargs: # Additional kwargs to pass to the hook init num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml index aa857c5ace..da0b972070 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -6,7 +6,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml index a58c42c521..05de8bfdcc 100644 --- a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -6,7 +6,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml index 0982d90aa8..5fb7fcbdd2 100644 --- a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml @@ -8,7 +8,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: method: iterative target_layer: "mlp.down_proj" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml index 60e421b239..1e2ecf07a0 100644 --- a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml @@ -8,7 +8,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: method: iterative target_layer: "mixer.down_proj" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml index 6a5922959d..18d7e234ac 100644 --- a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml @@ -8,7 +8,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: method: iterative target_layer: "mlp.down_proj" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml index 0b6fa59fbf..70dd5fd006 100644 --- a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml @@ -8,7 +8,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor.Qwen3_8BFFNIntermediateLayerDescriptor -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: method: iterative target_layer: "mlp.down_proj" diff --git a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py deleted file mode 100644 index 996d531392..0000000000 --- a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Forward hooks for estimating importance scores for pruning.""" - -from modelopt.torch.utils import import_plugin - -from .base_hooks import * -from .base_hooks_analysis import * - -with import_plugin("megatron_hooks"): - from .megatron_hooks import * diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py deleted file mode 100644 index a868fddc13..0000000000 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ /dev/null @@ -1,1180 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -"""Forward hooks for activation-based importance estimation.""" - -import gc -import json -from abc import ABC, abstractmethod -from datetime import datetime -from pathlib import Path - -import torch -import torch.nn.functional as F -from omegaconf import DictConfig, OmegaConf -from torch import nn - -import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 -from modelopt.torch.puzzletron.tools.logger import aprint -from modelopt.torch.puzzletron.tools.robust_json import json_dump - -__all__ = [ - "ForwardHook", - "IndependentChannelContributionHook", - "IndependentKvHeadContributionHook", - "IterativeChannelContributionHook", - "L2NormHook", - "LayerNormContributionHook", -] - - -def clear_gpu_memory(clear: bool) -> None: - """Clear GPU memory cache if requested. - - Args: - clear: If True, runs garbage collection and empties CUDA cache. - """ - if clear: - gc.collect() - torch.cuda.empty_cache() - - -class ForwardHook(ABC): - """Base class for PyTorch forward hooks. - - This follows the PyTorch forward hook API where the second - parameter is 'args' (a tuple of positional arguments passed to forward()). - - Usage: - hook = MyHook() - module.register_forward_hook(hook) - """ - - @abstractmethod - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor - ) -> None: - """Forward hook that is called after the module's forward pass. - - Args: - module: The module this hook is registered on - args: Tuple of positional arguments passed to module.forward() - output: The output from module.forward() - - Returns: - None (does not modify the output) - """ - ... - - @abstractmethod - def accumulate(self) -> torch.Tensor: - """Return accumulated importance scores. - - This method should be called after all forward passes to retrieve - the final importance scores for each channel/feature. - - Returns: - Tensor of importance scores, one per channel/feature. - - Raises: - AssertionError: If no activations have been collected yet. - """ - ... - - @abstractmethod - def state_dict(self) -> dict: - """Return the internal state for checkpointing. - - Returns: - dict: State dictionary containing checkpoint data. - Can contain tensors, ints, lists, etc. - """ - ... - - @abstractmethod - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint. - - Args: - state_dict: State dictionary previously returned by state_dict() - """ - ... - - def get_progress_info(self) -> dict: - """Get progress information for this hook. - - Returns: - dict: Progress information (e.g., current iteration, samples processed). - Default implementation returns empty dict. - """ - return {} - - @abstractmethod - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert hook results to dictionary format for saving. - - Returns: - dict: Dictionary containing result tensors (e.g., "score", "channels_importance_ascending"). - """ - ... - - @classmethod - def dump_activations_logs( - cls: type["ForwardHook"], - activation_hooks: dict[str, "ForwardHook"], - activations_log_dir: Path | str, - args: DictConfig, - ) -> None: - """Default implementation for dumping final activation scores logs to disk. - - This is called only at the end of scoring to save final results. - """ - activations_log_dir = Path(activations_log_dir) - activations_log_dir.mkdir(exist_ok=True, parents=True) - rank = dist.rank() - activations_log_path = activations_log_dir / f"rank_{rank}.pth" - activations_log = { - module_name: hook.to_dict() for module_name, hook in activation_hooks.items() - } - torch.save(activations_log, activations_log_path) - - if rank == 0: - if args.activation_hooks_kwargs is not None: - args.activation_hooks_kwargs.pop("model", None) - json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") - dist.barrier() - - aprint(f"Dumped final activations log to {activations_log_path}") - - @classmethod - def save_hook_states( - cls: type["ForwardHook"], - activation_hooks: dict[str, "ForwardHook"], - activations_log_dir: Path | str, - ) -> None: - """Save hook states for checkpointing (separate from final results). - - This can be called periodically during scoring. - Note: Synchronization should be handled at a higher level to avoid deadlocks. - """ - activations_log_dir = Path(activations_log_dir) - activations_log_dir.mkdir(exist_ok=True, parents=True) - rank = dist.rank() - - hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth" - hook_states = { - module_name: hook.state_dict() for module_name, hook in activation_hooks.items() - } - torch.save(hook_states, hook_states_path) - - -class L2NormHook(ForwardHook): - """Hook for accumulating activation statistics for importance estimation. - - Activations are computed as mean over seq_len and then squared and summed over batch_size. - In the accumulate() method we take the square root of the sum to get the L2 norm. - - This is the base version without tensor parallelism support. - For megatron with TP > 1, use MegatronL2NormHook instead. - - Args: - max_size: Optional maximum expected size to validate against (skips if mismatch). - Useful for skipping non-max subnets during profiling. - """ - - def __init__(self, max_size: int | None = None): - """Initialize the L2NormHook.""" - self.max_size = max_size - self._activations: torch.Tensor | None = None - - def _get_input_tensor(self, args: tuple[torch.Tensor, ...]) -> torch.Tensor: - """Get input tensor from args. Override in subclass for TP gathering.""" - return args[0].detach() - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor - ) -> None: - """Accumulate activation statistics from the forward pass. - - Args: - module: The module this hook is registered on. - args: Tuple of input tensors. args[0] expected shape: [seq_len, batch_size, hidden_size] - (Megatron sequence-first format). - output: Output tensor from the module's forward pass. - """ - input_tensor = self._get_input_tensor(args) - - if input_tensor.dim() == 2: - # For sparse experts, there is no batch dimension. - input_tensor = input_tensor[:, None, :] - - # Dont aggregate activations from non-max subnets (e.g. from profiling) - if self.max_size is not None and input_tensor.shape[-1] != self.max_size: - return - - input_tensor = input_tensor.to(torch.float32) # use full precision to avoid overflow - activations = input_tensor.abs().mean(dim=0) # [batch_size, hidden_size] - activations = activations.pow(2).sum(dim=0) # [hidden_size] - - if self._activations is None: - self._activations = activations - else: - self._activations += activations - - def accumulate(self) -> torch.Tensor: - """Return the accumulated L2 norm of activations. - - Returns: - Tensor of accumulated scores, one per channel - - Raises: - AssertionError: If no activations have been collected yet - """ - assert self._activations is not None, "No activations collected for importance estimation." - # Convert squared sum to L2 norm - return self._activations.pow(0.5) - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert to dict format for saving.""" - return {"score": self.accumulate().cpu()} - - def state_dict(self) -> dict: - """Return the state dictionary containing activations.""" - return {"activations": self._activations} - - def load_state_dict(self, state_dict: dict) -> None: - """Load activations from checkpoint.""" - self._activations = state_dict["activations"] - - -class IndependentChannelContributionHook(ForwardHook): - """Hook for channel importance estimation using weight norms and activation magnitudes. - - Computes channel importance as the product of: - - L2 norm of each column in the weight matrix (how much each input channel affects output) - - Mean absolute activation for each channel (how strongly each channel is activated) - - Args: - linear_layer: The linear projection layer to analyze. Must have a `weight` attribute - and either `in_features` (nn.Linear) or `input_size` (Megatron RowParallelLinear). - max_size: Optional maximum expected size to validate against (skips if mismatch). - Useful for skipping non-max subnets during profiling. - """ - - def __init__( - self, - linear_layer: nn.Module, - max_size: int | None = None, - ): - """Initialize the independent channel contribution hook.""" - self.max_size = max_size - - weight_matrix = linear_layer.weight.float() - self.weight_norm = torch.linalg.vector_norm(weight_matrix, dim=0) - - # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) - if hasattr(linear_layer, "input_size"): - self.num_channels = linear_layer.input_size # Megatron-Core - else: - self.num_channels = linear_layer.in_features # PyTorch - - self.agg_channel_activations = torch.zeros( - size=(self.num_channels,), - dtype=torch.float32, - device=weight_matrix.device, - ) - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple - ) -> None: - """Accumulate mean absolute activations per channel. - - Args: - module: The module this hook is registered on. - args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] - (PyTorch batch-first format). - output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) - for parallel layers. - """ - activations = args[0] - - # Don't aggregate activations from non-max subnets (e.g. from profiling) - if self.max_size is not None and activations.shape[-1] != self.max_size: - return - - mean_abs_channel_activations = ( - activations.abs().float().mean(dim=list(range(activations.ndim - 1))) - ) - self.agg_channel_activations[:] += mean_abs_channel_activations # shape [input_channels] - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert results to dict with channel importance scores. - - Returns: - Dict with "score" (weight_norm * activations), "weight_norm", and - "agg_channel_activations". - """ - return { - "score": (self.weight_norm * self.agg_channel_activations).cpu(), - "weight_norm": self.weight_norm.cpu(), - "agg_channel_activations": self.agg_channel_activations.cpu(), - } - - def accumulate(self) -> torch.Tensor: - """Return importance scores as a tensor. - - Returns: - Tensor of importance scores (weight_norm * activations), one per channel. - """ - return self.to_dict()["score"] - - def state_dict(self) -> dict: - """Save the internal state for checkpointing.""" - return { - "agg_channel_activations": self.agg_channel_activations.cpu().clone(), - "weight_norm": self.weight_norm.cpu().clone(), - } - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - self.agg_channel_activations = state_dict["agg_channel_activations"].to( - self.agg_channel_activations.device - ) - # weight_norm should be the same as it's derived from the model weights - # but we can verify it matches - expected_weight_norm = state_dict["weight_norm"].to(self.weight_norm.device) - if not torch.allclose(self.weight_norm, expected_weight_norm, rtol=1e-5): - raise AssertionError( - "weight_norm mismatch during state loading - model weights may have changed" - ) - - -def get_pruning_schedule(num_channels, pruning_iters): - """Spending decreases monotonically when num_channels >= pruning_iters. - - Intervals between spends increase monotonically when pruning_iters > num_channels. - The budget is fully utilized, and there's spending in the last iteration. - num_channels = 10, pruning_iters = 4 ==> [3, 3, 2, 2] - num_channels = 4, pruning_iters = 10 ==> [0, 1, 0, 1, 0, 0, 1, 0, 0, 1] - """ - if num_channels >= pruning_iters: - # Case when budget is greater than or equal to iterations - q = num_channels // pruning_iters # Base spend per iteration - r = num_channels % pruning_iters # Remainder to distribute - - schedule = [] - for i in range(pruning_iters): - if i < r: - # Assign higher spend to earlier iterations - schedule.append(q + 1) - else: - schedule.append(q) - else: - # Case when iterations are greater than budget - schedule = [0] * pruning_iters - for i in range(1, num_channels + 1): - # Distribute spends at positions where intervals increase monotonically - pos = ((i * pruning_iters) // num_channels) - 1 - schedule[pos] = 1 - return schedule - - -class IterativeChannelContributionHook(ForwardHook): - """Hook for iterative channel pruning based on contribution analysis. - - Progressively identifies and removes the least important input channels of a linear layer - by measuring channel contribution as the L2 norm of output change when removed. - - Args: - linear_layer: The linear projection layer to analyze. Must have a `weight` attribute - and either `in_features` (nn.Linear) or `input_size` (Megatron RowParallelLinear). - activation_hooks_kwargs: Configuration dict with: - - validation_full_iters (int): Number of pruning iterations. - - clear_gpu_memory (bool, optional): Clear GPU memory during computation. - - calibration_method (str, optional): "scale_by_magnitude" or None. - max_size: Optional maximum expected size to validate against (skips if mismatch). - Useful for skipping non-max subnets during profiling. - """ - - def __init__( - self, - linear_layer: nn.Module, - activation_hooks_kwargs: dict, - max_size: int | None = None, - ): - """Initialize the iterative channel contribution hook.""" - self.weight_matrix = linear_layer.weight - - # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) - # TODO: Consider better design to handle RowParallelLinear and nn.Linear - if hasattr(linear_layer, "input_size"): - self.num_channels = linear_layer.input_size # Megatron-Core - else: - self.num_channels = linear_layer.in_features # PyTorch - - self.max_size = max_size - self.pruning_iters = activation_hooks_kwargs["validation_full_iters"] - self.clear_gpu_memory = activation_hooks_kwargs.get("clear_gpu_memory", False) - self.curr_iter = 0 - self.pruning_schedule = get_pruning_schedule( - num_channels=self.num_channels, pruning_iters=self.pruning_iters - ) - - self.agg_cont_per_channel = torch.zeros( - size=(self.num_channels,), - dtype=torch.float32, - device=self.weight_matrix.device, - ) - self.pruned_channels = [] - self.calibration_method = activation_hooks_kwargs.get("calibration_method") - self.epsilon = 1e-8 - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple - ) -> None: - """Compute channel contributions and prune channels according to schedule. - - Args: - module: The module this hook is registered on. - args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] - (PyTorch batch-first format). - output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) - for parallel layers. - """ - # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear) - # TODO: Consider better design to handle RowParallelLinear and nn.Linear - if isinstance(output, tuple): - output_tensor = output[0] - else: - output_tensor = output - - activations = args[0] - - # Don't aggregate activations from non-max subnets (e.g. from profiling) - if self.max_size is not None and activations.shape[-1] != self.max_size: - return - - n_channels_to_prune = self.pruning_schedule[self.curr_iter] - - curr_activations = activations.clone() # Shape B,T,I - curr_activations[..., self.pruned_channels] = 0 - output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E - - if self.calibration_method is None: - scaling_factor_per_token = torch.ones_like(output_tensor[..., 0]) # Shape B,T - elif self.calibration_method == "scale_by_magnitude": - output_norms = torch.linalg.vector_norm(output_tensor, dim=-1) # Shape B,T - output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T - scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon) - del output_curr_norms, output_norms - else: - raise NotImplementedError - del curr_activations - clear_gpu_memory(clear=self.clear_gpu_memory) - - s = scaling_factor_per_token.unsqueeze(-1) * output_tensor - output_curr # Shape: (B, T, E) - s_squared_per_token = torch.sum(s**2, dim=-1) # Shape: (B, T) - b = s @ self.weight_matrix # Shape: (B, T, I) - c = torch.sum(self.weight_matrix**2, dim=0) # Shape: (I) - del s, output_curr - clear_gpu_memory(clear=self.clear_gpu_memory) - - contribution_squared = ( - s_squared_per_token.unsqueeze(2) + 2 * activations * b + (activations**2) * c - ) # Shape: (B, T, I) - del s_squared_per_token, b, c, activations - clear_gpu_memory(clear=self.clear_gpu_memory) - - contribution = torch.sqrt(contribution_squared + self.epsilon) # Shape: (B, T, I) - mean_cont_per_channel = torch.mean(contribution, dim=(0, 1)) # Shape: (I) - mean_cont_per_channel[self.pruned_channels] = torch.inf - del contribution, contribution_squared - clear_gpu_memory(clear=self.clear_gpu_memory) - - self.agg_cont_per_channel += mean_cont_per_channel - if n_channels_to_prune > 0: - _, worst_indices = torch.topk( - self.agg_cont_per_channel, n_channels_to_prune, largest=False - ) - worst_indices_list = worst_indices.tolist() - assert not set(self.pruned_channels).intersection(set(worst_indices_list)) - self.pruned_channels.extend(worst_indices_list) - self.agg_cont_per_channel.zero_() - self.curr_iter += 1 - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert pruning results to dict with channel importance rankings. - - Returns: - Dict with "score" (importance rank per channel) and - "channels_importance_ascending" (channel indices in ascending importance). - """ - assert self.num_channels == len(self.pruned_channels) - channels_importance_ascending = torch.tensor(self.pruned_channels, dtype=torch.long) - score = torch.empty(self.num_channels, dtype=torch.long) - score[channels_importance_ascending] = torch.arange(self.num_channels, dtype=torch.long) - - return { - "score": score.cpu(), - "channels_importance_ascending": channels_importance_ascending.cpu(), - } - - def accumulate(self) -> torch.Tensor: - """Return importance scores as a tensor. - - Returns: - Tensor of importance scores, one per channel. Lower scores indicate less important channels. - """ - return self.to_dict()["score"] - - def state_dict(self) -> dict: - """Save the internal state for checkpointing.""" - return { - "curr_iter": self.curr_iter, - "pruned_channels": self.pruned_channels.copy(), - "agg_cont_per_channel": self.agg_cont_per_channel.cpu().clone(), - "num_channels": self.num_channels, - "pruning_iters": self.pruning_iters, - "pruning_schedule": self.pruning_schedule.copy(), - "calibration_method": self.calibration_method, - "epsilon": self.epsilon, - } - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - self.curr_iter = state_dict["curr_iter"] - self.pruned_channels = state_dict["pruned_channels"].copy() - self.agg_cont_per_channel = state_dict["agg_cont_per_channel"].to(self.weight_matrix.device) - # Verify other parameters match - assert self.num_channels == state_dict["num_channels"], "Channel count mismatch" - assert self.pruning_iters == state_dict["pruning_iters"], "Iteration count mismatch" - assert self.pruning_schedule == state_dict["pruning_schedule"], "Pruning schedule mismatch" - - def get_progress_info(self) -> dict: - """Get progress information for this hook. - - Returns: - dict: Progress information including iteration count and pruned channels. - """ - progress = self.curr_iter / self.pruning_iters if self.pruning_iters > 0 else 0.0 - return { - "curr_iter": self.curr_iter, - "total_iters": self.pruning_iters, - "progress": progress, - "pruned_channels_count": len(self.pruned_channels), - "total_channels": self.num_channels, - } - - -class IndependentKvHeadContributionHook(ForwardHook): - """Hook for estimating KV head importance based on contribution analysis. - - Measures the contribution of each KV head group to the output projection - by computing L2 norms of per-head outputs. - - Args: - linear_layer: The output projection layer (o_proj). - activation_hooks_kwargs: Configuration dict with: - - model: The model instance (to get config). - - block_config: Block configuration with attention settings. - - optimize_for (str, optional): "latency" or "memory". Defaults to "memory". - """ - - def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): - """Initialize the KV head contribution hook.""" - model_config = activation_hooks_kwargs["model"].config - block_config = activation_hooks_kwargs["block_config"] - - self.optimize_for = activation_hooks_kwargs.get("optimize_for", "memory") - assert self.optimize_for in ["latency", "memory"] - - self.hidden_size = model_config.hidden_size - self.num_q_heads = model_config.num_attention_heads - self.num_kv_heads = block_config.attention.num_key_value_heads - self.n_heads_in_group = self.num_q_heads // self.num_kv_heads - self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) - - self.agg_kv_head_contributions = torch.zeros( - size=(self.num_kv_heads,), - dtype=torch.float32, - device=linear_layer.weight.device, - ) - - # Reshape weight matrix to group by KV heads - self.weight_grouped = linear_layer.weight.view( - self.hidden_size, self.num_kv_heads, self.head_dim * self.n_heads_in_group - ).permute((1, 0, 2)) - # weight_grouped.shape: (kv_heads, hidden_dim, head_dim * n_heads_in_group) - - def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: - """Compute KV head contributions from the forward pass.""" - attn_out = args[0] # Shape: (B, T, num_q_heads * head_dim) - batch_size, seq_len, _ = attn_out.shape - - # Reshape attention output to group by KV heads - attn_out_grouped = attn_out.view( - batch_size, - seq_len, - self.num_kv_heads, - self.head_dim * self.n_heads_in_group, - ).unsqueeze(-2) - # attn_out_grouped.shape: (B, T, kv_heads, 1, head_dim * n_heads_in_group) - - if self.optimize_for == "latency": - # Compute contribution per KV head group - # First compute the projection for each KV head group - layer_out_grouped = attn_out_grouped @ self.weight_grouped.transpose(-1, -2) - layer_out_grouped = layer_out_grouped.squeeze(-2) - # layer_out_grouped.shape: (B, T, kv_heads, hidden_dim) - - else: - layer_out_grouped = [] - for i in range(self.num_kv_heads): - _layer_out = attn_out_grouped[:, :, i] @ self.weight_grouped[i].transpose(-1, -2) - layer_out_grouped.append(_layer_out) - layer_out_grouped = torch.cat(layer_out_grouped, dim=2) - - # Compute L2 norm of each group's contribution - contrib_per_kv_head = torch.linalg.vector_norm(layer_out_grouped, dim=-1) - # contrib_per_kv_head.shape: (B, T, kv_heads) - - contrib_per_kv_head = contrib_per_kv_head.mean(dim=(0, 1)) - # contrib_per_kv_head.shape: (kv_heads,) - - # Accumulate contributions - self.agg_kv_head_contributions += contrib_per_kv_head - - def accumulate(self) -> torch.Tensor: - """Return accumulated KV head importance scores. - - Returns: - Tensor of importance scores, one per KV head. - """ - return self.agg_kv_head_contributions - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert to dict format for saving. - - Returns: - Dict with "score" tensor containing KV head importance scores. - """ - return { - "score": self.agg_kv_head_contributions.cpu(), - } - - def state_dict(self) -> dict: - """Return the internal state for checkpointing.""" - raise NotImplementedError("Saving state dict is not supported for this hook.") - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - raise NotImplementedError("Loading state dict is not supported for this hook.") - - -class LayerNormContributionHook(ForwardHook): - """Hook for estimating channel importance based on layer normalization activations. - - Aggregates mean absolute activation values per channel for a layer normalization layer. - - Args: - layernorm_layer: The layer normalization layer. - activation_hooks_kwargs: The activation hooks kwargs (not used). - """ - - def __init__(self, layernorm_layer: nn.Module, activation_hooks_kwargs: dict): - """Aggregates mean absolute activation values per channel for a layer normalization layer. - - Args: - layernorm_layer: The layer normalization layer - activation_hooks_kwargs: The activation hooks kwargs (not used) - """ - self.agg_embedding_activations = torch.zeros( - size=(layernorm_layer.weight.shape[0],), - dtype=torch.float32, - device=layernorm_layer.weight.device, - ) - - def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: - """Accumulate activation statistics from the forward pass.""" - self.agg_embedding_activations += ( - output.abs().float().mean(dim=list(range(output.ndim - 1))) - ) - - def accumulate(self) -> torch.Tensor: - """Return accumulated channel importance scores.""" - return self.agg_embedding_activations - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert to dict format for saving.""" - return { - "score": self.agg_embedding_activations.cpu(), - "channels_importance_ascending": self.agg_embedding_activations.sort()[1].cpu(), - } - - def state_dict(self) -> dict: - """Return the internal state for checkpointing.""" - raise NotImplementedError("Saving state dict is not supported for this hook.") - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - raise NotImplementedError("Loading state dict is not supported for this hook.") - - @classmethod - def dump_activations_logs( - cls: type["LayerNormContributionHook"], - activation_hooks: dict[str, "ForwardHook"], - activations_log_dir: Path | str, - args: DictConfig, - ) -> None: - """At the end of the default implementation of dumping activation scores to disc. - - Save aggregated channel importance results. - """ - super().dump_activations_logs(activation_hooks, activations_log_dir, args) - - rank = dist.rank() - if rank == 0: - LayerNormContributionHook._save_channel_importance_results( - activation_hooks, activations_log_dir, args - ) - - dist.barrier() - - @staticmethod - def _save_channel_importance_results( - activation_hooks: dict[str, "ForwardHook"], - activations_log_dir: Path | str, - args: DictConfig, - ) -> None: - """Save channel importance results from activation hooks.""" - # Find all activation files (for multi-rank scenarios) - activations_log_dir = Path(activations_log_dir) - activation_files = list(activations_log_dir.glob("rank_*.pth")) - if not activation_files: - aprint(f"Warning: No activation files found in {activations_log_dir}") - return - - # Load and aggregate activation data from all ranks - all_scores = [] - for activation_file in activation_files: - aprint(f"Loading activations from {activation_file}") - activation_data = torch.load(activation_file, map_location="cpu") - - # Extract scores from the activation data - for module_name, hook_data in activation_data.items(): - if "score" in hook_data: - scores = hook_data["score"] - all_scores.append(scores) - aprint(f"Loaded {len(scores)} channel scores from {module_name}") - - if not all_scores: - aprint("Warning: No valid activation data found") - return - - # Average scores across all ranks and modules - avg_scores = torch.stack(all_scores).mean(dim=0) - aprint(f"Averaged {len(all_scores)} score sets into {len(avg_scores)} channels") - - # Create channel importance ranking (descending order) - ranked_channels = torch.argsort(avg_scores, descending=True).tolist() - - # Create output data structure - timestamp = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") - output_data = { - "model_path": getattr(args, "model_name_or_path", "unknown"), - "dataset_path": getattr(args, "dataset_path", "unknown"), - "experiment_id": getattr(args, "experiment_id", f"experiment_{timestamp}"), - "eval_samples": getattr(args, "eval_samples", 0), - "micro_batch_size": getattr(args, "micro_batch_size", 0), - "timestamp": timestamp, - "total_channels": len(ranked_channels), - "channel_importance_ranking": ranked_channels, - "channel_scores": avg_scores.tolist(), - "score_statistics": { - "min": float(avg_scores.min()), - "max": float(avg_scores.max()), - "mean": float(avg_scores.mean()), - "std": float(avg_scores.std()), - }, - } - - # Save the output - output_path = activations_log_dir / "channel_importance_results.json" - aprint(f"Saving channel importance data to {output_path}") - with open(output_path, "w") as f: - json.dump(output_data, f, indent=2) - - # Print summary statistics - aprint("=== Channel Importance Summary ===") - aprint(f"Total channels: {len(ranked_channels)}") - aprint(f"Top 10 most important channels: {ranked_channels[:10]}") - aprint(f"Bottom 10 least important channels: {ranked_channels[-10:]}") - aprint(f"Score range: {avg_scores.min():.4f} to {avg_scores.max():.4f}") - aprint(f"Score mean: {avg_scores.mean():.4f}") - aprint(f"Score std: {avg_scores.std():.4f}") - - -class RemoveExpertsIndependentHook(ForwardHook, ABC): - """Base hook for measuring expert importance in Mixture-of-Experts models. - - This hook measures how much removing each expert affects the model output - by comparing outputs with and without each expert. - """ - - def __init__(self, moe: nn.Module, activation_hooks_kwargs: dict): - """Initialize the hook. - - Args: - moe: The MoE module to analyze - activation_hooks_kwargs: Configuration dict containing block_config - """ - self.moe = moe - block_config: BlockConfig = activation_hooks_kwargs["block_config"] - self.num_local_experts = block_config.ffn.moe.num_local_experts - self.num_experts_per_tok = block_config.ffn.moe.num_experts_per_tok - # tensor of zeros of size num experts - self.diffs = ["mse", "cosine"] - some_param = next(self.moe.parameters()) - self.diffs = { - k: torch.zeros( - size=(self.num_local_experts,), dtype=torch.float32, device=some_param.device - ) - for k in self.diffs - } - self.call_count = 0 - - @abstractmethod - def get_router_logits_and_routed_experts( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor]: - """Extract router logits and expert outputs for measuring expert importance. - - This method is called twice per forward pass: - 1. First call (router_logits=None): Compute original routing and expert outputs - 2. Second call (router_logits provided): Re-run with modified logits (expert disabled) - - Args: - hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) - router_logits: Optional pre-computed router logits. If None, compute from hidden_states. - - Returns: - tuple of (router_logits, routed_experts): - - router_logits: Shape (num_tokens, num_local_experts) - - routed_experts: Shape (num_tokens, hidden_dim) - """ - raise NotImplementedError - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor - ) -> None: - """Forward hook that measures expert importance.""" - hidden_states = args[0] - router_logits, original_routed_out = self.get_router_logits_and_routed_experts( - hidden_states - ) - - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - original_routed_out = original_routed_out.view(-1, original_routed_out.shape[-1]) - - _, router_indices = torch.topk(router_logits, self.num_experts_per_tok, dim=-1) - self.call_count += 1 - - for i_expert in range(self.num_local_experts): - expert_mask = router_indices == i_expert - is_token_routed_to_this_expert = expert_mask.any(dim=-1) - - num_tokens_displaced = is_token_routed_to_this_expert.sum() - if num_tokens_displaced == 0: - continue - num_total_tokens = is_token_routed_to_this_expert.numel() - - relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] - - router_logits_without_i = router_logits.clone() - router_logits_without_i[..., i_expert] = -float("inf") # disable expert i - router_logits_without_i = router_logits_without_i[is_token_routed_to_this_expert, :] - _, routed_out_without_i = self.get_router_logits_and_routed_experts( - relevant_hidden_states, router_logits_without_i - ) - - relevant_tokens_original_out = original_routed_out[is_token_routed_to_this_expert, :] - self.diffs["mse"][i_expert] += ( - nn.functional.mse_loss( - relevant_tokens_original_out, routed_out_without_i, reduction="mean" - ) - * num_tokens_displaced - / num_total_tokens - ) - self.diffs["cosine"][i_expert] += ( - -nn.functional.cosine_similarity( - relevant_tokens_original_out, routed_out_without_i, dim=-1 - ).mean() - * num_tokens_displaced - / num_total_tokens - ) - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert accumulated statistics to dict format.""" - expert_ranks_mse = torch.argsort(self.diffs["mse"]) - expert_ranks_cosine = torch.argsort(self.diffs["cosine"]) - return { - "expert_ranks_mse": expert_ranks_mse.cpu(), - "expert_ranks_cosine": expert_ranks_cosine.cpu(), - "cosine_diffs": (self.diffs["cosine"] / self.call_count).cpu(), - "mse_diffs": (self.diffs["mse"] / self.call_count).cpu(), - } - - def accumulate(self) -> torch.Tensor: - """Return accumulated expert importance scores.""" - return self.diffs["mse"] - - def state_dict(self) -> dict: - """Return the internal state for checkpointing.""" - return { - "diffs_mse": self.diffs["mse"].cpu(), - "diffs_cosine": self.diffs["cosine"].cpu(), - "call_count": self.call_count, - } - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - self.diffs["mse"] = state_dict["diffs_mse"].to(self.diffs["mse"].device) - self.diffs["cosine"] = state_dict["diffs_cosine"].to(self.diffs["cosine"].device) - self.call_count = state_dict["call_count"] - - -class NemotronHRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): - """Expert removal importance hook for NemotronH models.""" - - def get_router_logits_and_routed_experts( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor]: - """Extract router logits and expert outputs for NemotronH MoE. - - Based on NemotronHMOE forward, uses minimum ops to get router_logits and routed_experts. - """ - orig_shape = hidden_states.shape - # NemotronHMOE.gate forward, copied to extract router_logits - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if router_logits is None: - router_logits = nn.functional.linear( - hidden_states.type(torch.float32), self.moe.gate.weight.type(torch.float32) - ) - router_logits = router_logits.sigmoid() - router_logits = router_logits + self.moe.gate.e_score_correction_bias.unsqueeze(0) - - topk_indices = self._get_topk_indices_without_correction_bias(router_logits) - topk_weights = router_logits.gather(1, topk_indices) - if self.moe.gate.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.moe.gate.routed_scaling_factor - # Routed experts forward - hidden_states = self.moe.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) - return router_logits, hidden_states - - @torch.no_grad() - def _get_topk_indices_without_correction_bias(self, scores: torch.Tensor) -> torch.Tensor: - """Get topk indices without correction bias. - - Same as NemotronHMOE.gate.get_topk_indices but without adding e_score_correction_bias. - """ - group_scores = ( - scores.view( - -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group - ) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.moe.gate.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand( - -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group - ) - .reshape(-1, self.moe.gate.n_routed_experts) - ) - scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0) - topk_indices = torch.topk(scores_for_choice, k=self.moe.gate.top_k, dim=-1, sorted=False)[1] - return topk_indices - - -class RankedChoiceVotingHook(ForwardHook): - """Hook for ranking experts using ranked choice voting algorithm. - - This hook tracks router decisions and uses ranked choice voting to determine - which experts are least important (can be pruned first). - """ - - def __init__(self, router: nn.Module, activation_hooks_kwargs: dict): - """Initialize the hook. - - Args: - router: The router module (typically nn.Linear) - activation_hooks_kwargs: Configuration dict containing block_config - """ - self.router_argsort: list[torch.Tensor] = [] - block_config: BlockConfig = activation_hooks_kwargs["block_config"] - self.top_k = block_config.ffn.moe.num_experts_per_tok - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor - ) -> None: - """Forward hook that records router decisions. - - Args: - module: The router module - args: Tuple with one tensor entry (B, T, I) - output: Router logits of shape (B, T, E) - """ - router_logits = output[0] if isinstance(output, tuple) else output - num_experts = router_logits.shape[-1] - router_argsort = torch.argsort(router_logits, dim=-1, descending=True) - router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() - self.router_argsort.append(router_argsort) - - def to_dict(self) -> dict[str, torch.Tensor]: - """Convert accumulated statistics to dict format using ranked choice voting.""" - router_argsort = torch.concat(self.router_argsort, dim=0) - num_tokens, num_experts = router_argsort.shape - - expert_ranks = torch.full((num_experts,), -1) - expert_counts_at_pruning_time = {} - - expert_kept_per_iteration: list[list[int]] = [] - expert_counts_per_iteration: list[dict[int, int]] = [] - - for rank in range(num_experts): - ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True) - ids = ids.tolist() - counts = counts.tolist() - expert_counts = dict(zip(ids, counts)) - - expert_kept_per_iteration.append(ids) - expert_counts_per_iteration.append(expert_counts) - - least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1]) - - expert_ranks[least_popular_expert] = rank - expert_counts_at_pruning_time[least_popular_expert] = min_count - aprint(f"#{rank}: router_argsort shape = {router_argsort.shape}") - router_argsort = router_argsort[router_argsort != least_popular_expert].view( - num_tokens, -1 - ) - - zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long) - for expert_id, expert_counts_val in expert_counts_per_iteration[0].items(): - zero_shot_expert_counts[expert_id] = expert_counts_val - - # Compute zero-shot expert ranks (double argsort converts counts to rank positions) - zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts)) - - aprint("Done: Returning hook metadata.") - return { - "expert_ranks": expert_ranks, - "zero_shot_expert_ranks": zero_shot_expert_ranks, - "expert_counts_at_pruning_time": expert_counts_at_pruning_time, - "expert_counts_per_iteration": expert_counts_per_iteration, - "top_k": self.top_k, - } - - def accumulate(self) -> torch.Tensor: - """Return accumulated expert ranks.""" - if not self.router_argsort: - return torch.tensor([]) - router_argsort = torch.concat(self.router_argsort, dim=0) - return router_argsort[:, 0].float() - - def state_dict(self) -> dict: - """Return the internal state for checkpointing.""" - return { - "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort], - "top_k": self.top_k, - } - - def load_state_dict(self, state_dict: dict) -> None: - """Load the internal state from a checkpoint.""" - self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]] - self.top_k = state_dict["top_k"] - - def get_progress_info(self) -> dict: - """Get progress information.""" - return { - "num_batches_processed": len(self.router_argsort), - "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort) - if self.router_argsort - else 0, - } - - -class RankedChoiceVotingHookNemotronH(RankedChoiceVotingHook): - """Ranked choice voting hook for NemotronH models. - - In NemotronH, router_logits is an internal temporary state that never leaves - the forward() function. We reconstruct router_logits from the input hidden_states. - """ - - def __call__( - self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor - ) -> None: - """Forward hook that reconstructs router logits from hidden states.""" - hidden_states = args[0] - hidden_states = hidden_states.view(-1, module.config.hidden_size) - router_logits = nn.functional.linear( - hidden_states.type(torch.float32), module.weight.type(torch.float32) - ) - super().__call__(module, args, router_logits) - - -class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): - """Expert removal importance hook for Qwen3-VL models.""" - - def get_router_logits_and_routed_experts( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor]: - """Extract router logits and expert outputs for Qwen3-VL MoE. - - Based on Qwen3VLMoeSparseMoe forward pass. - """ - orig_shape = hidden_states.shape - - # Flatten to (num_tokens, hidden_size) for processing - hidden_states_flat = hidden_states.reshape(-1, self.moe.hidden_size) - - if router_logits is None: - router_logits = self.moe.gate(hidden_states_flat) - - routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, router_indices = torch.topk(routing_weights, self.moe.top_k, dim=-1) - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states_flat.dtype) - router_weights = torch.zeros_like(router_logits).scatter_( - 1, router_indices, routing_weights - ) - - # Reshape hidden_states for moe.experts (expects 3D: batch, seq, hidden) - # router_weights and router_indices remain 2D (num_tokens, num_experts) - batch_size = orig_shape[0] if hidden_states.ndim == 3 else 1 - hidden_states_3d = hidden_states_flat.reshape(batch_size, -1, self.moe.hidden_size) - - routed_out = self.moe.experts(hidden_states_3d, router_weights, router_indices) - - # Return in same shape as input - routed_out = routed_out.reshape(*orig_shape) - - return router_logits, routed_out diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py deleted file mode 100644 index dc338a7cfa..0000000000 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py +++ /dev/null @@ -1,104 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Analysis tools for evaluating importance scores from hooks.""" - -import torch -import torch.nn.functional as F -from torch import nn - -__all__ = ["evaluate_importance_scores"] - - -def evaluate_importance_scores( - linear_layer: nn.Linear, - activations_batches: list[torch.Tensor], - importance_scores: torch.Tensor, - prune_ratio: float = 0.2, -) -> dict[str, float]: - """Compute reconstruction error after pruning input channels of a linear layer. - - This function simulates channel pruning by zeroing out input channels identified as - least important, then measures how much the layer's output changes. - - Args: - linear_layer: The linear layer to analyze with shape (out_features, in_features). - For example: nn.Linear(in_features=1024, out_features=4096) - activations_batches: List of input activation tensors. - Each tensor has shape [seq_len, batch_size, in_features]. - The last dimension must match linear_layer.in_features. - Example: List of [16, 8, 1024] tensors - importance_scores: Importance score for each input channel (feature). - Shape: [in_features]. Lower scores = less important. - Example: [1024] tensor with one score per input feature - prune_ratio: Fraction of input channels to prune (default: 0.2 means prune 20%). - - Returns: - Dictionary containing averaged metrics across all activation batches: - - rmse: Root mean squared error between original and pruned output - - cosine_similarity: Cosine similarity between original and pruned output - - num_pruned: Number of input channels pruned - - Example: - >>> layer = nn.Linear(in_features=1024, out_features=4096) - >>> # Collect multiple batches for robust evaluation - >>> activations_list = [torch.randn(16, 8, 1024) for _ in range(100)] - >>> scores = torch.randn(1024) # one score per input feature - >>> metrics = evaluate_importance_scores(layer, activations_list, scores, 0.2) - >>> print(f"RMSE: {metrics['rmse']:.4f}, Pruned: {metrics['num_pruned']} channels") - - Note: - - This simulates pruning (zeros out inputs) without modifying layer weights - - "Channels" refers to INPUT features, not output features - - """ - num_channels = importance_scores.shape[0] - num_to_prune = int(num_channels * prune_ratio) - - # Identify channels to prune (lowest scoring = least important) - _, channels_to_prune = torch.topk(importance_scores, num_to_prune, largest=False) - - # Compute metrics for each batch and average - rmse_values = [] - cosine_values = [] - - for activations in activations_batches: - # Get original output - original_output = linear_layer(activations) - - # Prune by zeroing out identified channels - pruned_activations = activations.clone() - pruned_activations[..., channels_to_prune] = 0 - - # Get pruned output - pruned_output = linear_layer(pruned_activations) - - # Compute metrics for this batch - rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item() - rmse_values.append(rmse) - - # Cosine similarity (flatten to vectors) - original_flat = original_output.reshape(-1) - pruned_flat = pruned_output.reshape(-1) - cosine = F.cosine_similarity( - original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1 - ).item() - cosine_values.append(cosine) - - # Return averaged metrics - return { - "rmse": sum(rmse_values) / len(rmse_values), - "cosine_similarity": sum(cosine_values) / len(cosine_values), - "num_pruned": num_to_prune, - } diff --git a/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py b/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py deleted file mode 100644 index 316aff76ff..0000000000 --- a/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py +++ /dev/null @@ -1,291 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Compare module output tensors from different model variants. - -This module provides: -1. OutputSaveHook - A PyTorch hook to capture module outputs during forward pass -2. Comparison utilities - Compute RMSE and cosine similarity between saved outputs - -Usage Example: --------------- - -Step 1: Capture outputs from multiple layers: - - from modelopt.torch.nas.plugins.megatron_hooks.compare_module_outputs import ( - OutputSaveHook, - save_multi_layer_outputs, - ) - - # Register hooks on all target layers - hooks = {} - for name, module in model.named_modules(): - if name.endswith('mlp.linear_fc2'): - hook = OutputSaveHook(layer_name=name) - module.register_forward_hook(hook) - hooks[name] = hook - - # Run inference/training - model(input_data) - - # Save all layer outputs - save_multi_layer_outputs(hooks, "output_unpruned.pt") - -Step 2: Compare outputs from different model variants: - - python compare_module_outputs.py \ - --reference output_unpruned.pt \ - --compare output_l2norm.pt \ - --output-json comparison_stats.json - -The saved file format: -{ - 'decoder.layers.0.mlp.linear_fc2': Tensor([steps, seq_len, batch, hidden]), - 'decoder.layers.1.mlp.linear_fc2': Tensor([...]), - ... - 'metadata': {'num_layers': N, 'num_steps': M, 'layer_names': [...]} -} -""" - -import argparse - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class OutputSaveHook: - """Hook to capture and save module outputs during forward pass.""" - - def __init__(self, layer_name: str) -> None: - """Initialize the output save hook. - - Args: - layer_name: Hierarchical name of the layer (e.g., 'decoder.layers.0.mlp.linear_fc2'). - """ - self.layer_name = layer_name - self.saved_outputs: list[torch.Tensor] = [] - - def __call__( - self, - module: nn.Module, - args: tuple[torch.Tensor, ...], - output: torch.Tensor | tuple[torch.Tensor, ...], - ) -> None: - """Capture and save module output during forward pass. - - Args: - module: The PyTorch module being hooked. - args: Input arguments to the module's forward pass. - output: Output tensor(s) from the module's forward pass. - """ - # Handle tuple outputs (e.g., output, bias) - out = output[0] if isinstance(output, tuple) else output - self.saved_outputs.append(out.detach().cpu()) - - def get_outputs_list(self) -> list[torch.Tensor]: - """Return saved outputs as a list.""" - return self.saved_outputs - - -def save_multi_layer_outputs(hooks: dict[str, OutputSaveHook], path: str) -> None: - """Save outputs from multiple layers to a single file. - - Args: - hooks: Dictionary mapping layer names to their hooks. - path: Path to save the outputs. - """ - output_dict = {name: hook.get_outputs_list() for name, hook in hooks.items()} - - # Add metadata - output_dict["metadata"] = { - "num_layers": len(hooks), - # Number of forward passes (generation steps) - all hooks have same count, so use first hook - "num_steps": len(next(iter(hooks.values())).saved_outputs) if hooks else 0, - "layer_names": list(hooks.keys()), - } - - torch.save(output_dict, path) - print(f"\nSaved outputs from {len(hooks)} layers to {path}") - for name, data in output_dict.items(): - if name != "metadata": - print(f" {name}: list of {len(data)} tensors") - - -def compute_rmse(tensor1: torch.Tensor, tensor2: torch.Tensor) -> float: - """Compute Root Mean Square Error between two tensors.""" - mse = torch.mean((tensor1 - tensor2) ** 2) - rmse = torch.sqrt(mse) - return rmse.item() - - -def compute_cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor) -> dict: - """Compute average cosine similarity between two tensors.""" - # Flatten to 2D for cosine similarity computation - t1_flat = tensor1.reshape(-1, tensor1.shape[-1]) - t2_flat = tensor2.reshape(-1, tensor2.shape[-1]) - - # Compute cosine similarity per position - cos_sim = F.cosine_similarity(t1_flat, t2_flat, dim=-1) - - return { - "mean": cos_sim.mean().item(), - "min": cos_sim.min().item(), - "max": cos_sim.max().item(), - "std": cos_sim.std().item(), - } - - -def main(): - """Compare module output tensors from different model variants.""" - parser = argparse.ArgumentParser( - description="Compare module output tensors from different model variants", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__, - ) - parser.add_argument( - "--reference", - type=str, - required=True, - help="Path to reference output tensor (e.g., unpruned model)", - ) - parser.add_argument( - "--compare", - type=str, - required=True, - help="Path to output tensor to compare against reference", - ) - parser.add_argument( - "--output-json", - type=str, - default=None, - help="Path to save comparison statistics as JSON", - ) - args = parser.parse_args() - - # Load reference data - print(f"\nLoading reference: {args.reference}") - ref_data = torch.load(args.reference, map_location="cpu") - - # Load comparison data - print(f"Loading compare: {args.compare}") - comp_data = torch.load(args.compare, map_location="cpu") - - # Compare multi-layer outputs - compare_multi_layer(ref_data, comp_data, args.output_json) - - -def compute_layer_metrics(ref_data: list, comp_data: list) -> dict: - """Compute RMSE and cosine similarity for a layer's outputs. - - Args: - ref_data: List of reference tensors. - comp_data: List of comparison tensors. - - Returns: - Dictionary with metrics. - - Raises: - ValueError: If lengths don't match or tensor shapes don't match. - """ - if len(ref_data) != len(comp_data): - raise ValueError( - f"Length mismatch: reference has {len(ref_data)} samples, compare has {len(comp_data)}" - ) - - rmse_values = [] - cos_sim_values = [] - - for ref_tensor, comp_tensor in zip(ref_data, comp_data): - if ref_tensor.shape != comp_tensor.shape: - raise ValueError( - f"Shape mismatch at index {len(rmse_values)}: " - f"reference {ref_tensor.shape} vs compare {comp_tensor.shape}" - ) - rmse_values.append(compute_rmse(ref_tensor, comp_tensor)) - cos_sim = compute_cosine_similarity(ref_tensor, comp_tensor) - cos_sim_values.append(cos_sim["mean"]) - - return { - "rmse": sum(rmse_values) / len(rmse_values), - "cosine_sim": { - "mean": sum(cos_sim_values) / len(cos_sim_values), - "min": min(cos_sim_values), - "max": max(cos_sim_values), - "std": torch.tensor(cos_sim_values).std().item() if len(cos_sim_values) > 1 else 0.0, - }, - "num_samples": len(rmse_values), - } - - -def compare_multi_layer(ref_data: dict, comp_data: dict, output_json: str | None = None): - """Compare multi-layer outputs.""" - import json - - ref_layers = [k for k in ref_data if k != "metadata"] - comp_layers = [k for k in comp_data if k != "metadata"] - - if set(ref_layers) != set(comp_layers): - print("\nERROR: Layer mismatch!") - print(f"Reference layers: {ref_layers}") - print(f"Compare layers: {comp_layers}") - return - - results = {"aggregated": {"rmse": [], "cosine_sim_mean": []}, "per_layer": {}} - - # Per-layer comparison - for layer_name in sorted(ref_layers): - ref_layer_data = ref_data[layer_name] - comp_layer_data = comp_data[layer_name] - - metrics = compute_layer_metrics(ref_layer_data, comp_layer_data) - - results["per_layer"][layer_name] = metrics - results["aggregated"]["rmse"].append(metrics["rmse"]) - results["aggregated"]["cosine_sim_mean"].append(metrics["cosine_sim"]["mean"]) - - # Aggregated statistics - if results["aggregated"]["rmse"]: - rmse_array = torch.tensor(results["aggregated"]["rmse"]) - cos_sim_array = torch.tensor(results["aggregated"]["cosine_sim_mean"]) - - results["aggregated"]["rmse_stats"] = { - "mean": rmse_array.mean().item(), - "std": rmse_array.std().item(), - "min": rmse_array.min().item(), - "max": rmse_array.max().item(), - } - results["aggregated"]["cosine_sim_stats"] = { - "mean": cos_sim_array.mean().item(), - "std": cos_sim_array.std().item(), - "min": cos_sim_array.min().item(), - "max": cos_sim_array.max().item(), - } - results["aggregated"]["num_steps"] = ref_data.get("metadata", {}).get("num_steps", None) - results["aggregated"]["num_layers"] = len(rmse_array) - - # Save to JSON if requested - if output_json: - # Remove raw lists for JSON serialization - results["aggregated"].pop("rmse", None) - results["aggregated"].pop("cosine_sim_mean", None) - - with open(output_json, "w") as f: - json.dump(results, f, indent=2) - print(f"Saved comparison results to {output_json}") - - -if __name__ == "__main__": - main() diff --git a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py deleted file mode 100644 index d792ff8941..0000000000 --- a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Megatron-specific hooks with tensor parallelism support.""" - -import torch -from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region - -from .base_hooks import L2NormHook - -__all__ = ["MegatronL2NormHook"] - - -class MegatronL2NormHook(L2NormHook): - """L2NormHook with tensor parallelism support for Megatron models. - - Extends L2NormHook to gather activations across all tensor parallel regions - before computing importance scores. - """ - - def _get_input_tensor(self, args: tuple[torch.Tensor, ...]) -> torch.Tensor: - """Gather input tensor from all TP regions.""" - # Gather input [seq_len, batch_size, hidden_size] over all TP regions - # NOTE: This is not used at the moment since we restrict to TP=1 - return gather_from_tensor_model_parallel_region(args[0]).detach() diff --git a/modelopt/torch/prune/importance_hooks/__init__.py b/modelopt/torch/prune/importance_hooks/__init__.py index 3bf30c2a46..1e86ddcf65 100644 --- a/modelopt/torch/prune/importance_hooks/__init__.py +++ b/modelopt/torch/prune/importance_hooks/__init__.py @@ -18,6 +18,7 @@ from .base_hooks import * from .base_hooks_analysis import * +from .expert_removal_hooks import * with import_plugin("megatron_hooks"): from .plugins.megatron_hooks import * diff --git a/modelopt/torch/prune/importance_hooks/base_hooks.py b/modelopt/torch/prune/importance_hooks/base_hooks.py index 248e6ec108..a28908d4b6 100644 --- a/modelopt/torch/prune/importance_hooks/base_hooks.py +++ b/modelopt/torch/prune/importance_hooks/base_hooks.py @@ -149,7 +149,8 @@ def dump_activations_logs( torch.save(activations_log, activations_log_path) if rank == 0: - args.activation_hooks_kwargs.pop("model") + if args.activation_hooks_kwargs is not None: + args.activation_hooks_kwargs.pop("model", None) json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") dist.barrier() @@ -565,9 +566,9 @@ def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): assert self.optimize_for in ["latency", "memory"] self.hidden_size = model_config.hidden_size - self.n_heads_in_group = block_config.attention.n_heads_in_group self.num_q_heads = model_config.num_attention_heads - self.num_kv_heads = self.num_q_heads // self.n_heads_in_group + self.num_kv_heads = block_config.attention.num_key_value_heads + self.n_heads_in_group = self.num_q_heads // self.num_kv_heads self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) self.agg_kv_head_contributions = torch.zeros( diff --git a/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py b/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py new file mode 100644 index 0000000000..68eaf2e711 --- /dev/null +++ b/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py @@ -0,0 +1,387 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""MoE expert-removal and ranked-choice importance hooks (uses Puzzletron BlockConfig).""" + +from abc import ABC, abstractmethod + +import torch +from torch import nn + +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 + +from .base_hooks import ForwardHook + +__all__ = [ + "NemotronHRemoveExpertsIndependentHook", + "Qwen3VLRemoveExpertsIndependentHook", + "RankedChoiceVotingHook", + "RankedChoiceVotingHookNemotronH", + "RemoveExpertsIndependentHook", +] + + +class RemoveExpertsIndependentHook(ForwardHook, ABC): + """Base hook for measuring expert importance in Mixture-of-Experts models. + + This hook measures how much removing each expert affects the model output + by comparing outputs with and without each expert. + """ + + def __init__(self, moe: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + moe: The MoE module to analyze + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.moe = moe + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.num_local_experts = block_config.ffn.moe.num_local_experts + self.num_experts_per_tok = block_config.ffn.moe.num_experts_per_tok + # tensor of zeros of size num experts + self.diffs = ["mse", "cosine"] + some_param = next(self.moe.parameters()) + self.diffs = { + k: torch.zeros( + size=(self.num_local_experts,), dtype=torch.float32, device=some_param.device + ) + for k in self.diffs + } + self.call_count = 0 + + @abstractmethod + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for measuring expert importance. + + This method is called twice per forward pass: + 1. First call (router_logits=None): Compute original routing and expert outputs + 2. Second call (router_logits provided): Re-run with modified logits (expert disabled) + + Args: + hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) + router_logits: Optional pre-computed router logits. If None, compute from hidden_states. + + Returns: + tuple of (router_logits, routed_experts): + - router_logits: Shape (num_tokens, num_local_experts) + - routed_experts: Shape (num_tokens, hidden_dim) + """ + raise NotImplementedError + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that measures expert importance.""" + hidden_states = args[0] + router_logits, original_routed_out = self.get_router_logits_and_routed_experts( + hidden_states + ) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + original_routed_out = original_routed_out.view(-1, original_routed_out.shape[-1]) + + _, router_indices = torch.topk(router_logits, self.num_experts_per_tok, dim=-1) + self.call_count += 1 + + for i_expert in range(self.num_local_experts): + expert_mask = router_indices == i_expert + is_token_routed_to_this_expert = expert_mask.any(dim=-1) + + num_tokens_displaced = is_token_routed_to_this_expert.sum() + if num_tokens_displaced == 0: + continue + num_total_tokens = is_token_routed_to_this_expert.numel() + + relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] + + router_logits_without_i = router_logits.clone() + router_logits_without_i[..., i_expert] = -float("inf") # disable expert i + router_logits_without_i = router_logits_without_i[is_token_routed_to_this_expert, :] + _, routed_out_without_i = self.get_router_logits_and_routed_experts( + relevant_hidden_states, router_logits_without_i + ) + + relevant_tokens_original_out = original_routed_out[is_token_routed_to_this_expert, :] + self.diffs["mse"][i_expert] += ( + nn.functional.mse_loss( + relevant_tokens_original_out, routed_out_without_i, reduction="mean" + ) + * num_tokens_displaced + / num_total_tokens + ) + self.diffs["cosine"][i_expert] += ( + -nn.functional.cosine_similarity( + relevant_tokens_original_out, routed_out_without_i, dim=-1 + ).mean() + * num_tokens_displaced + / num_total_tokens + ) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format.""" + expert_ranks_mse = torch.argsort(self.diffs["mse"]) + expert_ranks_cosine = torch.argsort(self.diffs["cosine"]) + return { + "expert_ranks_mse": expert_ranks_mse.cpu(), + "expert_ranks_cosine": expert_ranks_cosine.cpu(), + "cosine_diffs": (self.diffs["cosine"] / self.call_count).cpu(), + "mse_diffs": (self.diffs["mse"] / self.call_count).cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert importance scores.""" + return self.diffs["mse"] + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "diffs_mse": self.diffs["mse"].cpu(), + "diffs_cosine": self.diffs["cosine"].cpu(), + "call_count": self.call_count, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.diffs["mse"] = state_dict["diffs_mse"].to(self.diffs["mse"].device) + self.diffs["cosine"] = state_dict["diffs_cosine"].to(self.diffs["cosine"].device) + self.call_count = state_dict["call_count"] + + +class NemotronHRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for NemotronH models.""" + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for NemotronH MoE. + + Based on NemotronHMOE forward, uses minimum ops to get router_logits and routed_experts. + """ + orig_shape = hidden_states.shape + # NemotronHMOE.gate forward, copied to extract router_logits + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + if router_logits is None: + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), self.moe.gate.weight.type(torch.float32) + ) + router_logits = router_logits.sigmoid() + router_logits = router_logits + self.moe.gate.e_score_correction_bias.unsqueeze(0) + + topk_indices = self._get_topk_indices_without_correction_bias(router_logits) + topk_weights = router_logits.gather(1, topk_indices) + if self.moe.gate.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.moe.gate.routed_scaling_factor + # Routed experts forward + hidden_states = self.moe.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + return router_logits, hidden_states + + @torch.no_grad() + def _get_topk_indices_without_correction_bias(self, scores: torch.Tensor) -> torch.Tensor: + """Get topk indices without correction bias. + + Same as NemotronHMOE.gate.get_topk_indices but without adding e_score_correction_bias. + """ + group_scores = ( + scores.view( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.moe.gate.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .reshape(-1, self.moe.gate.n_routed_experts) + ) + scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.moe.gate.top_k, dim=-1, sorted=False)[1] + return topk_indices + + +class RankedChoiceVotingHook(ForwardHook): + """Hook for ranking experts using ranked choice voting algorithm. + + This hook tracks router decisions and uses ranked choice voting to determine + which experts are least important (can be pruned first). + """ + + def __init__(self, router: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + router: The router module (typically nn.Linear) + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.router_argsort: list[torch.Tensor] = [] + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.top_k = block_config.ffn.moe.num_experts_per_tok + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that records router decisions. + + Args: + module: The router module + args: Tuple with one tensor entry (B, T, I) + output: Router logits of shape (B, T, E) + """ + router_logits = output[0] if isinstance(output, tuple) else output + num_experts = router_logits.shape[-1] + router_argsort = torch.argsort(router_logits, dim=-1, descending=True) + router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() + self.router_argsort.append(router_argsort) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format using ranked choice voting.""" + router_argsort = torch.concat(self.router_argsort, dim=0) + num_tokens, num_experts = router_argsort.shape + + expert_ranks = torch.full((num_experts,), -1) + expert_counts_at_pruning_time = {} + + expert_kept_per_iteration: list[list[int]] = [] + expert_counts_per_iteration: list[dict[int, int]] = [] + + for rank in range(num_experts): + ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True) + ids = ids.tolist() + counts = counts.tolist() + expert_counts = dict(zip(ids, counts)) + + expert_kept_per_iteration.append(ids) + expert_counts_per_iteration.append(expert_counts) + + least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1]) + + expert_ranks[least_popular_expert] = rank + expert_counts_at_pruning_time[least_popular_expert] = min_count + print(f"#{rank}: router_argsort shape = {router_argsort.shape}") + router_argsort = router_argsort[router_argsort != least_popular_expert].view( + num_tokens, -1 + ) + + zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long) + for expert_id, expert_counts_val in expert_counts_per_iteration[0].items(): + zero_shot_expert_counts[expert_id] = expert_counts_val + + # Compute zero-shot expert ranks (double argsort converts counts to rank positions) + zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts)) + + print("Done: Returning hook metadata.") + return { + "expert_ranks": expert_ranks, + "zero_shot_expert_ranks": zero_shot_expert_ranks, + "expert_counts_at_pruning_time": expert_counts_at_pruning_time, + "expert_counts_per_iteration": expert_counts_per_iteration, + "top_k": self.top_k, + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert ranks.""" + if not self.router_argsort: + return torch.tensor([]) + router_argsort = torch.concat(self.router_argsort, dim=0) + return router_argsort[:, 0].float() + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort], + "top_k": self.top_k, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]] + self.top_k = state_dict["top_k"] + + def get_progress_info(self) -> dict: + """Get progress information.""" + return { + "num_batches_processed": len(self.router_argsort), + "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort) + if self.router_argsort + else 0, + } + + +class RankedChoiceVotingHookNemotronH(RankedChoiceVotingHook): + """Ranked choice voting hook for NemotronH models. + + In NemotronH, router_logits is an internal temporary state that never leaves + the forward() function. We reconstruct router_logits from the input hidden_states. + """ + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that reconstructs router logits from hidden states.""" + hidden_states = args[0] + hidden_states = hidden_states.view(-1, module.config.hidden_size) + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), module.weight.type(torch.float32) + ) + super().__call__(module, args, router_logits) + + +class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for Qwen3-VL models.""" + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for Qwen3-VL MoE. + + Based on Qwen3VLMoeSparseMoe forward pass. + """ + orig_shape = hidden_states.shape + + # Flatten to (num_tokens, hidden_size) for processing + hidden_states_flat = hidden_states.reshape(-1, self.moe.hidden_size) + + if router_logits is None: + router_logits = self.moe.gate(hidden_states_flat) + + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.moe.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states_flat.dtype) + router_weights = torch.zeros_like(router_logits).scatter_( + 1, router_indices, routing_weights + ) + + # Reshape hidden_states for moe.experts (expects 3D: batch, seq, hidden) + # router_weights and router_indices remain 2D (num_tokens, num_experts) + batch_size = orig_shape[0] if hidden_states.ndim == 3 else 1 + hidden_states_3d = hidden_states_flat.reshape(batch_size, -1, self.moe.hidden_size) + + routed_out = self.moe.experts(hidden_states_3d, router_weights, router_indices) + + # Return in same shape as input + routed_out = routed_out.reshape(*orig_shape) + + return router_logits, routed_out diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py index 33243c0125..ccf73f7612 100644 --- a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -21,7 +21,7 @@ import torch -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook as ActivationsHook +from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook as ActivationsHook from modelopt.torch.puzzletron.tools.logger import aprint from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock, DummyModule diff --git a/modelopt/torch/puzzletron/anymodel/README.md b/modelopt/torch/puzzletron/anymodel/README.md index 9dea9d45f9..291966eb7b 100644 --- a/modelopt/torch/puzzletron/anymodel/README.md +++ b/modelopt/torch/puzzletron/anymodel/README.md @@ -129,8 +129,8 @@ activation_hooks_kwargs: ### Adding a New Hook Class -1. **Implement the hook** in `modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`: - - Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook`) +1. **Implement the hook** under `modelopt/torch/prune/importance_hooks/` (e.g. `base_hooks.py` for generic hooks, `expert_removal_hooks.py` for MoE expert removal): + - Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook` in `expert_removal_hooks.py`) - Implement required methods (e.g., `get_router_logits_and_routed_experts`) 2. **Register the hook** in the appropriate pruning mixin's `supported_hooks()`: @@ -159,9 +159,9 @@ activation_hooks_kwargs: | Type | Mixin | Example Hooks | |------|-------|---------------| -| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | -| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | -| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | +| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../prune/importance_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../prune/importance_hooks/base_hooks.py) | +| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py) | +| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../prune/importance_hooks/base_hooks.py) | ## Implementing `block_config_to_layer_overrides` diff --git a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py index 3c00ca212a..42c4ad8f51 100644 --- a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py @@ -19,8 +19,8 @@ import torch from transformers import PretrainedConfig -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( - ForwardHook, +from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook +from modelopt.torch.prune.importance_hooks.expert_removal_hooks import ( NemotronHRemoveExpertsIndependentHook, Qwen3VLRemoveExpertsIndependentHook, RankedChoiceVotingHook, diff --git a/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py index b3d9b88847..9b7993de1e 100644 --- a/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py @@ -20,7 +20,7 @@ import torch from transformers import PretrainedConfig -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( +from modelopt.torch.prune.importance_hooks.base_hooks import ( ForwardHook, IndependentChannelContributionHook, IterativeChannelContributionHook, diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py index f93e4b77ab..4a6fe53a34 100644 --- a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -18,7 +18,7 @@ from transformers import PretrainedConfig -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( +from modelopt.torch.prune.importance_hooks.base_hooks import ( ForwardHook, IndependentKvHeadContributionHook, ) diff --git a/modelopt/torch/puzzletron/pruning/pruning_mixin.py b/modelopt/torch/puzzletron/pruning/pruning_mixin.py index bcb422c4e6..21685848bf 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/pruning_mixin.py @@ -18,7 +18,7 @@ from abc import ABC, abstractmethod from typing import List, Optional, Tuple, Type -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook +from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook class LayerDescriptor: diff --git a/modelopt/torch/puzzletron/utils/checkpoint_manager.py b/modelopt/torch/puzzletron/utils/checkpoint_manager.py index 90303e2de9..a1347deaea 100644 --- a/modelopt/torch/puzzletron/utils/checkpoint_manager.py +++ b/modelopt/torch/puzzletron/utils/checkpoint_manager.py @@ -187,7 +187,7 @@ def update_progress(self, batch_idx: int, total_batches: int): # All ranks save their hook states if self.activation_hooks is not None: try: - from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook + from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook ForwardHook.save_hook_states(self.activation_hooks, self.checkpoint_dir) except Exception as e: @@ -240,7 +240,7 @@ def finalize(self): # All ranks save their final hook states if self.activation_hooks is not None: try: - from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook + from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook saved_path = ForwardHook.save_hook_states( self.activation_hooks, self.checkpoint_dir diff --git a/modelopt/torch/utils/robust_json.py b/modelopt/torch/utils/robust_json.py index c4a72fde83..23a3091637 100644 --- a/modelopt/torch/utils/robust_json.py +++ b/modelopt/torch/utils/robust_json.py @@ -55,8 +55,13 @@ def default(self, o): # User-defined function in main — fallback to just the name return o.__name__ return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" if isinstance(o, datetime.timedelta): return str(o) + # Fallback for arbitrary objects (e.g. mixins injected into Hydra configs) + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" return super().default(o) diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py deleted file mode 100644 index aa73a3be19..0000000000 --- a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py +++ /dev/null @@ -1,100 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for base hooks.""" - -import torch -import torch.nn as nn - -from modelopt.torch.nas.plugins.megatron_hooks import IterativeChannelContributionHook, L2NormHook - - -def _test_iterative_channel_contribution_hook_with_shape(dim1: int, dim2: int): - """Helper function to test IterativeChannelContributionHook with given activation shape. - - Args: - dim1: First dimension of activation tensor (before in_features). - dim2: Second dimension of activation tensor (before in_features). - """ - torch.manual_seed(42) - - linear_layer = nn.Linear(in_features=6, out_features=4, bias=False) - activation_hooks_kwargs = { - "validation_full_iters": 3, - "clear_gpu_memory": False, - "calibration_method": None, - } - hook = IterativeChannelContributionHook(linear_layer, activation_hooks_kwargs) - linear_layer.register_forward_hook(hook) - - for _ in range(activation_hooks_kwargs["validation_full_iters"]): - activations = torch.randn(dim1, dim2, linear_layer.in_features) - _ = linear_layer(activations) - - results = hook.to_dict() - - # - # Assertions - # - assert results["score"].shape == (6,) - assert results["channels_importance_ascending"].shape == (6,) - - expected_scores = torch.tensor([5, 1, 3, 2, 4, 0]) - assert torch.equal(results["score"], expected_scores) - - expected_channels_asc = torch.tensor([5, 1, 3, 2, 4, 0]) - assert torch.equal(results["channels_importance_ascending"], expected_channels_asc) - - # Test that accumulate() returns the same scores as to_dict()["score"] - scores_from_accumulate = hook.accumulate() - assert torch.equal(scores_from_accumulate, expected_scores) - - -def test_iterative_channel_contribution_hook_sbi(): - """Test IterativeChannelContributionHook returns correct scores for input [seq_len, batch_size, in_features].""" - _test_iterative_channel_contribution_hook_with_shape(dim1=32, dim2=8) - - -def test_iterative_channel_contribution_hook_bsi(): - """Test IterativeChannelContributionHook returns correct scores for input [batch_size, seq_len, in_features].""" - _test_iterative_channel_contribution_hook_with_shape(dim1=8, dim2=32) - - -def test_l2_norm_hook(): - """Test L2NormHook returns correct scores after accumulating activations.""" - torch.manual_seed(42) - - linear_layer = nn.Linear(in_features=6, out_features=4, bias=False) - hook = L2NormHook(max_size=None) - linear_layer.register_forward_hook(hook) - - num_iterations = 3 - for _ in range(num_iterations): - activations = torch.randn(2, 3, linear_layer.in_features) - _ = linear_layer(activations) - - scores = hook.accumulate() - - # - # Assertions - # - assert scores.shape == (6,) - - expected_scores = torch.tensor( - [3.2030, 2.5018, 2.5272, 1.9222, 2.6204, 2.2623], dtype=torch.float32 - ) - assert torch.allclose(scores, expected_scores, atol=1e-4), ( - f"Expected scores {expected_scores}, got {scores}" - ) diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py deleted file mode 100644 index 954c6e11c7..0000000000 --- a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for base hooks analysis tools.""" - -import pytest -import torch -import torch.nn as nn - -from modelopt.torch.nas.plugins.megatron_hooks import ( - IndependentChannelContributionHook, - IterativeChannelContributionHook, - L2NormHook, - evaluate_importance_scores, -) - - -def test_evaluate_importance_scores_basic(): - """Test basic functionality of importance score evaluation with synthetic scores.""" - torch.manual_seed(42) - - # Create a simple linear layer (same dimensions as other tests for comparability) - layer = nn.Linear(in_features=50, out_features=30, bias=False) - - # Create synthetic hook that generates sequential importance scores - hook = SyntheticImportanceHook(num_features=50) - - # Use shared helper to run evaluation - metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) - - print(f"[SyntheticImportanceHook] Metrics: {metrics}") - - # Check values with deterministic seed - assert metrics["num_pruned"] == 20 # 40% of 50 = 20 - assert metrics["rmse"] == pytest.approx(0.3689444, rel=1e-5) - assert metrics["cosine_similarity"] == pytest.approx(0.77117118, rel=1e-5) - - -def test_evaluate_importance_scores_with_l2_norm_hook(): - """Test evaluate_importance_scores with L2NormHook.""" - torch.manual_seed(42) - - # Create layer and hook - layer = nn.Linear(in_features=50, out_features=30, bias=False) - hook = L2NormHook(max_size=None) - - # Run evaluation - metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) - - print(f"[L2NormHook] Metrics: {metrics}") - - # L2NormHook specific assertions - assert metrics["num_pruned"] == 20 # 40% of 50 = 20 - assert metrics["rmse"] == pytest.approx(0.3616334, rel=1e-5) - assert metrics["cosine_similarity"] == pytest.approx(0.7814186, rel=1e-5) - - -def test_evaluate_importance_scores_with_iterative_channel_contribution_hook(): - """Test evaluate_importance_scores with IterativeChannelContributionHook.""" - torch.manual_seed(42) - - # Create layer and hook - layer = nn.Linear(in_features=50, out_features=30, bias=False) - activation_hooks_kwargs = { - "validation_full_iters": 1000, - "clear_gpu_memory": False, - "calibration_method": None, - } - hook = IterativeChannelContributionHook(layer, activation_hooks_kwargs) - - # Run evaluation - metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) - - print(f"[IterativeChannelContributionHook] Metrics: {metrics}") - - # Iterative channel contribution hook specific assertions - assert metrics["num_pruned"] == 20 # 40% of 50 = 20 - assert metrics["rmse"] == pytest.approx(0.339014, rel=1e-5) - assert metrics["cosine_similarity"] == pytest.approx(0.8110392, rel=1e-5) - - -def test_evaluate_importance_scores_with_independent_channel_contribution_hook(): - """Test evaluate_importance_scores with IndependentChannelContributionHook.""" - torch.manual_seed(42) - - # Create layer and hook - layer = nn.Linear(in_features=50, out_features=30, bias=False) - hook = IndependentChannelContributionHook(layer) - - # Run evaluation - metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) - - print(f"[IndependentChannelContributionHook] Metrics: {metrics}") - - # Independent channel contribution hook specific assertions - assert metrics["num_pruned"] == 20 # 40% of 50 = 20 - assert metrics["rmse"] == pytest.approx(0.3385471, rel=1e-5) - assert metrics["cosine_similarity"] == pytest.approx(0.8116209, rel=1e-5) - - -def _run_hook_and_evaluate( - layer: nn.Linear, - hook, - num_iterations: int, - prune_ratio: float, -) -> dict: - """Shared helper to run hook, collect scores, and evaluate. - - Args: - layer: Linear layer to test - hook: Hook instance (already created) - num_iterations: Number of forward passes - prune_ratio: Fraction of channels to prune - - Returns: - Dictionary with evaluation metrics - """ - handle = layer.register_forward_hook(hook) # Store the handle - - # Run forward passes - all_activations = [] - for _ in range(num_iterations): - activations = torch.randn(16, 8, layer.in_features) # seq=16, batch=8, in_features=50 - all_activations.append(activations) - _ = layer(activations) - - # Get importance scores from hook - importance_scores = hook.accumulate() - - # Remove the hook before evaluation to avoid triggering it again - handle.remove() - - # Evaluate the importance scores by simulating pruning on all collected activations - # Pass the list of activations to compute averaged metrics across batches - metrics = evaluate_importance_scores( - layer, - all_activations, # List of activation batches - importance_scores, - prune_ratio=prune_ratio, - ) - - return metrics - - -class SyntheticImportanceHook: - """Synthetic hook that generates sequential importance scores for testing. - - This is a simple mock hook that doesn't compute real importance, - just returns torch.arange(num_features) to test the evaluation pipeline. - """ - - def __init__(self, num_features: int): - """Initialize with the number of features.""" - self.num_features = num_features - - def __call__(self, module, args, output): - """Hook callback - does nothing for synthetic hook.""" - - def accumulate(self) -> torch.Tensor: - """Return synthetic importance scores: [0, 1, 2, ..., num_features-1].""" - return torch.arange(self.num_features, dtype=torch.float32) diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml index 81c5f35ba5..bc1124617e 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml @@ -9,7 +9,7 @@ pruning_mixin: _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor.Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor target_name: "mlp" -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.Qwen3VLRemoveExpertsIndependentHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.Qwen3VLRemoveExpertsIndependentHook} activation_hooks_kwargs: # num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist) diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml index 4c2335becf..ae20b6d7d2 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml @@ -9,7 +9,7 @@ pruning_mixin: _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor.NemotronHExpertRemovalLayerDescriptor target_name: "mixer" -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.NemotronHRemoveExpertsIndependentHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.NemotronHRemoveExpertsIndependentHook} activation_hooks_kwargs: # Additional kwargs to pass to the hook init num_experts_to_keep_list: [96, 64, 32, 16, 8] # num_experts in teacher is 128 diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml index cb1147d86b..abc501287d 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml @@ -7,7 +7,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: # Additional kwargs to pass to the hook init intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml index 4656f1df42..5a4761886f 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml @@ -9,7 +9,7 @@ pruning_mixin: layer_descriptor: _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor target_name: "mlp.router" -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.RankedChoiceVotingHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook} activation_hooks_kwargs: # Additional kwargs to pass to the hook init num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml index 7306b6e379..0dadc20134 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml @@ -9,7 +9,7 @@ pruning_mixin: layer_descriptor: _target_: ??? -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IndependentKvHeadContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IndependentKvHeadContributionHook} activation_hooks_kwargs: method: independent_kv_head_contribution optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml index 7e19afbbce..c1c951984f 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml @@ -9,7 +9,7 @@ pruning_mixin: layer_descriptor: _target_: ??? -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: method: iterative target_layer: "mlp.down_proj" diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index e39a7b22ed..de68ba500a 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -38,10 +38,6 @@ SEED = 1234 -def _assert_approx(actual, expected, abs=1e-3): - assert actual == pytest.approx(expected, abs=abs), f"{actual=} != {expected=}" - - def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): set_seed(SEED) # Use relatively bigger model here for more accurate test for sorting @@ -181,7 +177,6 @@ def _get_model(initialize_megatron=True): return model model = _get_model() - sd = model.state_dict() def forward_loop(m): @@ -218,48 +213,6 @@ def forward_loop(m): assert pruning_scores["layer_scores"] assert pruning_scores["local_activations"] - # TODO: Simplify it: this unit test is too long, - # hard to read (the same set of assertions across different test cases with if-else). - - assert len(pruning_scores["activations_per_rank"]) == size - activations = pruning_scores["activations_per_rank"][rank] - - # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4) - if size == 1 and pruned_ffn_div == 4: - # Layer scores - _assert_approx(pruning_scores["layer_scores"], {1: 0.028923, 2: 0.046508}) - - # Validate decoder.layers.0.mlp activations - mlp_0_acts = activations["decoder.layers.0.mlp"] - _assert_approx(mlp_0_acts.min().item(), 0.000026) - _assert_approx(mlp_0_acts.max().item(), 0.000729) - _assert_approx(mlp_0_acts.mean().item(), 0.000201) - - # Validate decoder.layers.1.mlp activations - mlp_1_acts = activations["decoder.layers.1.mlp"] - _assert_approx(mlp_1_acts.min().item(), 0.000022) - _assert_approx(mlp_1_acts.max().item(), 0.000762) - _assert_approx(mlp_1_acts.mean().item(), 0.000162) - - # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2) - elif size == 1 and pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1: - # Layer scores - _assert_approx(pruning_scores["layer_scores"], {1: 0.028056, 2: 0.038353}) - - # Validate decoder.layers.0.self_attention activations - attn_0_acts = activations["decoder.layers.0.self_attention"] - assert attn_0_acts.shape == torch.Size([hidden_size]) - _assert_approx(attn_0_acts.min().item(), 0.010091) - _assert_approx(attn_0_acts.max().item(), 0.023826) - _assert_approx(attn_0_acts.mean().item(), 0.014548) - - # Validate decoder.layers.1.self_attention activations - attn_1_acts = activations["decoder.layers.1.self_attention"] - assert attn_1_acts.shape == torch.Size([hidden_size]) - _assert_approx(attn_1_acts.min().item(), 0.009982) - _assert_approx(attn_1_acts.max().item(), 0.035644) - _assert_approx(attn_1_acts.mean().item(), 0.020140) - # Assert weights are pruned correctly for layer in model.decoder.layers: assert layer.mlp.linear_fc1.weight.shape == ( From 7e15fddf7b10558a10920f9b32cfbc119d239af3 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 30 Mar 2026 01:15:02 -0700 Subject: [PATCH 53/62] Revert CICD and other config changes Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/example_tests.yml | 96 ++++++- .github/workflows/gpu_tests.yml | 18 +- .pre-commit-config.yaml | 13 +- .../decilm/deci_lm_hf_code/block_config.py | 4 +- pyproject.toml | 1 - tests/gpu/torch/conftest.py | 59 ----- tox.ini | 10 - uv.lock | 244 +----------------- 8 files changed, 111 insertions(+), 334 deletions(-) delete mode 100644 tests/gpu/torch/conftest.py diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index 693fe0ec45..848b3d326d 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -56,6 +56,68 @@ jobs: match_pattern: "^DCO$|^linux$" # Wait for DCO and Unit tests / linux to pass delay: 300s + ##### PyTorch Example Tests (speculative_decoding requires 26.01 image) ##### + torch-pr: + needs: [check-file-changes, wait-checks] + if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' + strategy: &torch_strategy + fail-fast: false + matrix: + example: [llm_distill, llm_qat, llm_sparsity] + include: + - example: speculative_decoding + docker_image: "26.01" + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3" + example: ${{ matrix.example }} + timeout_minutes: 30 + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-h100-latest-1 + + torch-non-pr: + if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: *torch_strategy + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3" + example: ${{ matrix.example }} + timeout_minutes: 30 + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-rtxpro6000-latest-2 + + ##### TensorRT-LLM Example Tests ##### + trtllm-pr: + needs: [check-file-changes, wait-checks] + if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' + strategy: + fail-fast: false + matrix: + example: [llm_ptq, vlm_ptq] + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc5" + example: ${{ matrix.example }} + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-rtxpro6000-latest-1 + + trtllm-non-pr: + if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: + fail-fast: false + matrix: + example: [llm_autodeploy, llm_eval, llm_ptq, vlm_ptq] + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc5" + example: ${{ matrix.example }} + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-rtxpro6000-latest-2 + ##### NeMo Example Tests ##### nemo-pr: needs: [check-file-changes, wait-checks] @@ -85,17 +147,47 @@ jobs: pip_install_extras: "[hf,puzzletron,dev-test]" runner: linux-amd64-gpu-rtxpro6000-latest-2 + ##### ONNX/TensorRT Example Tests ##### + onnx-pr: + needs: [check-file-changes, wait-checks] + if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' + strategy: &onnx_strategy + fail-fast: false + matrix: + example: [diffusers, torch_onnx] + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt:26.02-py3" + example: ${{ matrix.example }} + pip_install_extras: "[all,dev-test]" + runner: linux-amd64-gpu-l4-latest-1 + + onnx-non-pr: + if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: *onnx_strategy + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt:26.02-py3" + example: ${{ matrix.example }} + pip_install_extras: "[all,dev-test]" + runner: linux-amd64-gpu-rtxpro6000-latest-2 + ##### Required Check for PR ##### example-pr-required-check: # Run even if example tests are skipped if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} - needs: [check-file-changes, nemo-pr] + needs: [check-file-changes, torch-pr, trtllm-pr, nemo-pr, onnx-pr] runs-on: ubuntu-latest steps: - name: Required GPU tests did not succeed if: | needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && ( - needs.nemo-pr.result != 'success' + needs.torch-pr.result != 'success' || + needs.trtllm-pr.result != 'success' || + needs.nemo-pr.result != 'success' || + needs.onnx-pr.result != 'success' )) run: exit 1 diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index 0f8e484163..542e948909 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -62,16 +62,16 @@ jobs: fail-fast: false matrix: include: - - example: gpu-puzzletron - timeout: 30 + - example: gpu + timeout: 45 container_image: pytorch:26.01-py3 - # - example: gpu-megatron - # timeout: 45 - # container_image: pytorch:26.01-py3 - # - example: gpu-trtllm - # timeout: 30 - # container_image: tensorrt-llm/release:1.3.0rc5 - runs-on: linux-amd64-gpu-rtxpro6000-latest-2 + - example: gpu-megatron + timeout: 45 + container_image: pytorch:26.01-py3 + - example: gpu-trtllm + timeout: 30 + container_image: tensorrt-llm/release:1.3.0rc5 + runs-on: linux-amd64-gpu-rtxpro6000-latest-1 timeout-minutes: ${{ matrix.timeout }} container: &gpu_container image: nvcr.io/nvidia/${{ matrix.container_image }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 554655f9dc..cd7f922fbb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,18 +25,9 @@ repos: hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - # See: commit hooks modifies block_config.py leading to test_puzzletron.py failing (#25) · Issues · omniml / modelopt · GitLab - exclude: > - (?x)^( - ^examples/specdec_bench/specdec_bench/datasets/speed\.py$| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py - )$ + exclude: ^examples/specdec_bench/specdec_bench/datasets/speed\.py$ - id: ruff-format - exclude: > - (?x)^( - ^examples/specdec_bench/specdec_bench/datasets/speed\.py$| - modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config\.py - )$ + exclude: ^examples/specdec_bench/specdec_bench/datasets/speed\.py$ - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.17.1 diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py index a7212516a7..fb630335c6 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py @@ -222,7 +222,9 @@ def __post_init__(self): elif self.is_moe: self._force_setattr("intermediate_size", None) else: - assert self.intermediate_size is not None, "Intermediate size must be provided for an FFN block" + assert self.intermediate_size is not None, ( + "Intermediate size must be provided for an FFN block" + ) def to_blockconfig(self) -> "BlockConfig": return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self) diff --git a/pyproject.toml b/pyproject.toml index d7df6eaecd..d88ed5b807 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,6 @@ puzzletron = [ # Dependedencies for modelopt.torch.puzzletron subpackage "immutabledict", "lru-dict", "mip", - "omegaconf==2.3.0", "pandas", "typeguard", ] diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py deleted file mode 100644 index a38322d141..0000000000 --- a/tests/gpu/torch/conftest.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch -import torch.distributed as dist -from _test_utils.torch.distributed.utils import init_process - -import modelopt.torch.opt as mto - - -@pytest.fixture -def distributed_setup_size_1(): - init_process(rank=0, size=1, backend="nccl") - yield - dist.destroy_process_group() - - -@pytest.fixture -def need_2_gpus(): - if torch.cuda.device_count() < 2: - pytest.skip("Need at least 2 GPUs to run this test") - - -@pytest.fixture -def need_8_gpus(): - if torch.cuda.device_count() < 8: - pytest.skip("Need at least 8 GPUs to run this test") - - -@pytest.fixture -def need_4_gpus(): - if torch.cuda.device_count() < 4: - pytest.skip("Need at least 4 GPUs to run this test") - - -@pytest.fixture(scope="module") -def set_torch_dtype(request): - orig_dtype = torch.get_default_dtype() - torch.set_default_dtype(request.param) - yield - torch.set_default_dtype(orig_dtype) - - -@pytest.fixture(scope="session", autouse=True) -def enable_hf_checkpointing(): - mto.enable_huggingface_checkpointing() diff --git a/tox.ini b/tox.ini index ea5ba1ac86..80299d814d 100644 --- a/tox.ini +++ b/tox.ini @@ -70,16 +70,6 @@ commands = # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" python -m pytest tests/gpu -[testenv:cuda13-gpu-puzzletron] -commands_pre = - # Install deps here so that it gets installed even in --current-env - pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git - pip install --no-build-isolation git+https://github.com/Dao-AILab/causal-conv1d.git - pip install -e .[hf,puzzletron,dev-test] -commands = - # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" - python -m pytest tests/gpu/torch/puzzletron - [testenv:cuda13-gpu-megatron] commands_pre = # Install deps here so that it gets installed even in --current-env diff --git a/uv.lock b/uv.lock index 5e49c2e8cb..d890e361cb 100644 --- a/uv.lock +++ b/uv.lock @@ -257,18 +257,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/06/f3/39cf3367b8107baa44f861dc802cbf16263c945b62d8265d36034fc07bea/cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114", size = 13918, upload-time = "2026-03-09T20:51:27.33Z" }, ] -[[package]] -name = "cbcbox" -version = "2.924" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/86/acd4af8ab4b00dbffac101bb2c87f2d84bdf3e2f3fa79171582c91770178/cbcbox-2.924-py3-none-macosx_15_0_arm64.whl", hash = "sha256:5ba1be40e761c47cbf6e94783f89be60e280d62d049f8f11c69c50fe5719a0ad", size = 30115522, upload-time = "2026-03-12T02:44:27.26Z" }, - { url = "https://files.pythonhosted.org/packages/2f/1c/3d528eb20a94db16c01e14d1f3e307fad73210c35fbd1dfd41b4214b4d64/cbcbox-2.924-py3-none-macosx_15_0_x86_64.whl", hash = "sha256:7f05a5c81c39e94ba32f5230ce9a4c93a899329941fa06f290a538af6121c4cc", size = 59928608, upload-time = "2026-03-12T02:44:30.814Z" }, - { url = "https://files.pythonhosted.org/packages/18/4b/a6ea7c4f600c071a4f6e653054a2172d315269229610d0efb9afcf67af77/cbcbox-2.924-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6842d5a646d650bad77ceddffd89f09a64d111f047b4744ce2e01361f352efed", size = 35904686, upload-time = "2026-03-12T02:44:34.347Z" }, - { url = "https://files.pythonhosted.org/packages/eb/96/9c3d681116a9df29273b48454aec4961b715e6ea01cbeea6ac3636c106d1/cbcbox-2.924-py3-none-manylinux2014_x86_64.whl", hash = "sha256:283c37212a63d2af55ed618653b45d3e66561c8f58e2c5542e3a323e400c6dc3", size = 72713045, upload-time = "2026-03-12T02:44:38.623Z" }, - { url = "https://files.pythonhosted.org/packages/de/29/59500c00eed48de52242889ac94fc6e9ad5fd780646d78847913274fe9e2/cbcbox-2.924-py3-none-win_amd64.whl", hash = "sha256:e5d0b308f89e56bba50286506417fceee0ef085ecf0eda30ed7beb5fea3abb4f", size = 57577749, upload-time = "2026-03-12T02:44:43.133Z" }, -] - [[package]] name = "certifi" version = "2026.2.25" @@ -278,54 +266,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/3c/c17fb3ca2d9c3acff52e30b309f538586f9f5b9c9cf454f3845fc9af4881/certifi-2026.2.25-py3-none-any.whl", hash = "sha256:027692e4402ad994f1c42e52a4997a9763c646b73e4096e4d5d6db8af1d6f0fa", size = 153684, upload-time = "2026-02-25T02:54:15.766Z" }, ] -[[package]] -name = "cffi" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pycparser", marker = "implementation_name != 'PyPy'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/93/d7/516d984057745a6cd96575eea814fe1edd6646ee6efd552fb7b0921dec83/cffi-2.0.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44", size = 184283, upload-time = "2025-09-08T23:22:08.01Z" }, - { url = "https://files.pythonhosted.org/packages/9e/84/ad6a0b408daa859246f57c03efd28e5dd1b33c21737c2db84cae8c237aa5/cffi-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49", size = 180504, upload-time = "2025-09-08T23:22:10.637Z" }, - { url = "https://files.pythonhosted.org/packages/50/bd/b1a6362b80628111e6653c961f987faa55262b4002fcec42308cad1db680/cffi-2.0.0-cp310-cp310-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c", size = 208811, upload-time = "2025-09-08T23:22:12.267Z" }, - { url = "https://files.pythonhosted.org/packages/4f/27/6933a8b2562d7bd1fb595074cf99cc81fc3789f6a6c05cdabb46284a3188/cffi-2.0.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb", size = 216402, upload-time = "2025-09-08T23:22:13.455Z" }, - { url = "https://files.pythonhosted.org/packages/05/eb/b86f2a2645b62adcfff53b0dd97e8dfafb5c8aa864bd0d9a2c2049a0d551/cffi-2.0.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0", size = 203217, upload-time = "2025-09-08T23:22:14.596Z" }, - { url = "https://files.pythonhosted.org/packages/9f/e0/6cbe77a53acf5acc7c08cc186c9928864bd7c005f9efd0d126884858a5fe/cffi-2.0.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4", size = 203079, upload-time = "2025-09-08T23:22:15.769Z" }, - { url = "https://files.pythonhosted.org/packages/98/29/9b366e70e243eb3d14a5cb488dfd3a0b6b2f1fb001a203f653b93ccfac88/cffi-2.0.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453", size = 216475, upload-time = "2025-09-08T23:22:17.427Z" }, - { url = "https://files.pythonhosted.org/packages/21/7a/13b24e70d2f90a322f2900c5d8e1f14fa7e2a6b3332b7309ba7b2ba51a5a/cffi-2.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495", size = 218829, upload-time = "2025-09-08T23:22:19.069Z" }, - { url = "https://files.pythonhosted.org/packages/60/99/c9dc110974c59cc981b1f5b66e1d8af8af764e00f0293266824d9c4254bc/cffi-2.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5", size = 211211, upload-time = "2025-09-08T23:22:20.588Z" }, - { url = "https://files.pythonhosted.org/packages/49/72/ff2d12dbf21aca1b32a40ed792ee6b40f6dc3a9cf1644bd7ef6e95e0ac5e/cffi-2.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb", size = 218036, upload-time = "2025-09-08T23:22:22.143Z" }, - { url = "https://files.pythonhosted.org/packages/e2/cc/027d7fb82e58c48ea717149b03bcadcbdc293553edb283af792bd4bcbb3f/cffi-2.0.0-cp310-cp310-win32.whl", hash = "sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a", size = 172184, upload-time = "2025-09-08T23:22:23.328Z" }, - { url = "https://files.pythonhosted.org/packages/33/fa/072dd15ae27fbb4e06b437eb6e944e75b068deb09e2a2826039e49ee2045/cffi-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739", size = 182790, upload-time = "2025-09-08T23:22:24.752Z" }, - { url = "https://files.pythonhosted.org/packages/12/4a/3dfd5f7850cbf0d06dc84ba9aa00db766b52ca38d8b86e3a38314d52498c/cffi-2.0.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe", size = 184344, upload-time = "2025-09-08T23:22:26.456Z" }, - { url = "https://files.pythonhosted.org/packages/4f/8b/f0e4c441227ba756aafbe78f117485b25bb26b1c059d01f137fa6d14896b/cffi-2.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c", size = 180560, upload-time = "2025-09-08T23:22:28.197Z" }, - { url = "https://files.pythonhosted.org/packages/b1/b7/1200d354378ef52ec227395d95c2576330fd22a869f7a70e88e1447eb234/cffi-2.0.0-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92", size = 209613, upload-time = "2025-09-08T23:22:29.475Z" }, - { url = "https://files.pythonhosted.org/packages/b8/56/6033f5e86e8cc9bb629f0077ba71679508bdf54a9a5e112a3c0b91870332/cffi-2.0.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93", size = 216476, upload-time = "2025-09-08T23:22:31.063Z" }, - { url = "https://files.pythonhosted.org/packages/dc/7f/55fecd70f7ece178db2f26128ec41430d8720f2d12ca97bf8f0a628207d5/cffi-2.0.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5", size = 203374, upload-time = "2025-09-08T23:22:32.507Z" }, - { url = "https://files.pythonhosted.org/packages/84/ef/a7b77c8bdc0f77adc3b46888f1ad54be8f3b7821697a7b89126e829e676a/cffi-2.0.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664", size = 202597, upload-time = "2025-09-08T23:22:34.132Z" }, - { url = "https://files.pythonhosted.org/packages/d7/91/500d892b2bf36529a75b77958edfcd5ad8e2ce4064ce2ecfeab2125d72d1/cffi-2.0.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26", size = 215574, upload-time = "2025-09-08T23:22:35.443Z" }, - { url = "https://files.pythonhosted.org/packages/44/64/58f6255b62b101093d5df22dcb752596066c7e89dd725e0afaed242a61be/cffi-2.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9", size = 218971, upload-time = "2025-09-08T23:22:36.805Z" }, - { url = "https://files.pythonhosted.org/packages/ab/49/fa72cebe2fd8a55fbe14956f9970fe8eb1ac59e5df042f603ef7c8ba0adc/cffi-2.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414", size = 211972, upload-time = "2025-09-08T23:22:38.436Z" }, - { url = "https://files.pythonhosted.org/packages/0b/28/dd0967a76aab36731b6ebfe64dec4e981aff7e0608f60c2d46b46982607d/cffi-2.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743", size = 217078, upload-time = "2025-09-08T23:22:39.776Z" }, - { url = "https://files.pythonhosted.org/packages/2b/c0/015b25184413d7ab0a410775fdb4a50fca20f5589b5dab1dbbfa3baad8ce/cffi-2.0.0-cp311-cp311-win32.whl", hash = "sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5", size = 172076, upload-time = "2025-09-08T23:22:40.95Z" }, - { url = "https://files.pythonhosted.org/packages/ae/8f/dc5531155e7070361eb1b7e4c1a9d896d0cb21c49f807a6c03fd63fc877e/cffi-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5", size = 182820, upload-time = "2025-09-08T23:22:42.463Z" }, - { url = "https://files.pythonhosted.org/packages/95/5c/1b493356429f9aecfd56bc171285a4c4ac8697f76e9bbbbb105e537853a1/cffi-2.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d", size = 177635, upload-time = "2025-09-08T23:22:43.623Z" }, - { url = "https://files.pythonhosted.org/packages/ea/47/4f61023ea636104d4f16ab488e268b93008c3d0bb76893b1b31db1f96802/cffi-2.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d", size = 185271, upload-time = "2025-09-08T23:22:44.795Z" }, - { url = "https://files.pythonhosted.org/packages/df/a2/781b623f57358e360d62cdd7a8c681f074a71d445418a776eef0aadb4ab4/cffi-2.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c", size = 181048, upload-time = "2025-09-08T23:22:45.938Z" }, - { url = "https://files.pythonhosted.org/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe", size = 212529, upload-time = "2025-09-08T23:22:47.349Z" }, - { url = "https://files.pythonhosted.org/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062", size = 220097, upload-time = "2025-09-08T23:22:48.677Z" }, - { url = "https://files.pythonhosted.org/packages/c2/95/7a135d52a50dfa7c882ab0ac17e8dc11cec9d55d2c18dda414c051c5e69e/cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e", size = 207983, upload-time = "2025-09-08T23:22:50.06Z" }, - { url = "https://files.pythonhosted.org/packages/3a/c8/15cb9ada8895957ea171c62dc78ff3e99159ee7adb13c0123c001a2546c1/cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037", size = 206519, upload-time = "2025-09-08T23:22:51.364Z" }, - { url = "https://files.pythonhosted.org/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba", size = 219572, upload-time = "2025-09-08T23:22:52.902Z" }, - { url = "https://files.pythonhosted.org/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94", size = 222963, upload-time = "2025-09-08T23:22:54.518Z" }, - { url = "https://files.pythonhosted.org/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187", size = 221361, upload-time = "2025-09-08T23:22:55.867Z" }, - { url = "https://files.pythonhosted.org/packages/7b/2b/2b6435f76bfeb6bbf055596976da087377ede68df465419d192acf00c437/cffi-2.0.0-cp312-cp312-win32.whl", hash = "sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18", size = 172932, upload-time = "2025-09-08T23:22:57.188Z" }, - { url = "https://files.pythonhosted.org/packages/f8/ed/13bd4418627013bec4ed6e54283b1959cf6db888048c7cf4b4c3b5b36002/cffi-2.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5", size = 183557, upload-time = "2025-09-08T23:22:58.351Z" }, - { url = "https://files.pythonhosted.org/packages/95/31/9f7f93ad2f8eff1dbc1c3656d7ca5bfd8fb52c9d786b4dcf19b2d02217fa/cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6", size = 177762, upload-time = "2025-09-08T23:22:59.668Z" }, -] - [[package]] name = "cfgv" version = "3.5.0" @@ -697,18 +637,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/a5/842ae8f0c08b61d6484b52f99a03510a3a72d23141942d216ebe81fefbce/filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70", size = 26759, upload-time = "2026-03-11T20:45:37.437Z" }, ] -[[package]] -name = "fire" -version = "0.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "termcolor" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/00/f8d10588d2019d6d6452653def1ee807353b21983db48550318424b5ff18/fire-0.7.1.tar.gz", hash = "sha256:3b208f05c736de98fb343310d090dcc4d8c78b2a89ea4f32b837c586270a9cbf", size = 88720, upload-time = "2025-08-16T20:20:24.175Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/4c/93d0f85318da65923e4b91c1c2ff03d8a458cbefebe3bc612a6693c7906d/fire-0.7.1-py3-none-any.whl", hash = "sha256:e43fd8a5033a9001e7e2973bab96070694b9f12f2e0ecf96d4683971b5ab1882", size = 115945, upload-time = "2025-08-16T20:20:22.87Z" }, -] - [[package]] name = "flatbuffers" version = "25.12.19" @@ -921,20 +849,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794, upload-time = "2021-09-17T21:40:39.897Z" }, ] -[[package]] -name = "hydra-core" -version = "1.3.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "antlr4-python3-runtime" }, - { name = "omegaconf" }, - { name = "packaging" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/6d/8e/07e42bc434a847154083b315779b0a81d567154504624e181caf2c71cd98/hydra-core-1.3.2.tar.gz", hash = "sha256:8a878ed67216997c3e9d88a8e72e7b4767e81af37afb4ea3334b269a4390a824", size = 3263494, upload-time = "2023-02-23T18:33:43.03Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547, upload-time = "2023-02-23T18:33:40.801Z" }, -] - [[package]] name = "identify" version = "2.6.18" @@ -962,15 +876,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/53/fb7122b71361a0d121b669dcf3d31244ef75badbbb724af388948de543e2/imagesize-2.0.0-py2.py3-none-any.whl", hash = "sha256:5667c5bbb57ab3f1fa4bc366f4fbc971db3d5ed011fd2715fd8001f782718d96", size = 9441, upload-time = "2026-03-03T14:18:27.892Z" }, ] -[[package]] -name = "immutabledict" -version = "4.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/e6/718471048fea0366c3e3d1df3acfd914ca66d571cdffcf6d37bbcd725708/immutabledict-4.3.1.tar.gz", hash = "sha256:f844a669106cfdc73f47b1a9da003782fb17dc955a54c80972e0d93d1c63c514", size = 7806, upload-time = "2026-02-15T10:32:34.668Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/ce/f9018bf69ae91b273b6391a095e7c93fa5e1617f25b6ba81ad4b20c9df10/immutabledict-4.3.1-py3-none-any.whl", hash = "sha256:c9facdc0ff30fdb8e35bd16532026cac472a549e182c94fa201b51b25e4bf7bf", size = 5000, upload-time = "2026-02-15T10:32:33.672Z" }, -] - [[package]] name = "importlib-metadata" version = "9.0.0" @@ -1064,57 +969,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bf/bf/b7802b6578ca3a6506aaac6696ac1e8de500419fee3cd288184e82a8c2aa/lief-0.17.6-cp313-cp313-win_arm64.whl", hash = "sha256:6d4eb8adce400af52cc174ac5cbe40ab10b9df5824193975d12e2d4f85b298a3", size = 3461166, upload-time = "2026-03-18T06:58:32.975Z" }, ] -[[package]] -name = "lru-dict" -version = "1.4.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/06/0a/dec86efe38b350314c49a8d39ef01ba7cf8bbbef1d177646320eedea7159/lru_dict-1.4.1.tar.gz", hash = "sha256:cc518ff2d38cc7a8ab56f9a6ae557f91e2e1524b57ed8e598e97f45a2bd708fc", size = 13439, upload-time = "2025-11-02T10:02:13.548Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a4/6c/396716746ca46fd2ac52a7a6cbd7b4cf848e5d430f431dacd209290dfa71/lru_dict-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3766e397aa6de1ca3442729bc1fa75834ab7b0a6b017e6e197d3a66b61abde59", size = 16757, upload-time = "2025-11-02T10:00:55.767Z" }, - { url = "https://files.pythonhosted.org/packages/2d/93/c163ffb71beb18f18459461658fd16c8b8c86aed858f2dc7c7e636318f61/lru_dict-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658e152d3a4ad0e1d75e6f53b1fa353779539920b38be99f4ea33d3bad41efdb", size = 11243, upload-time = "2025-11-02T10:00:56.715Z" }, - { url = "https://files.pythonhosted.org/packages/44/e3/fa96d54032531c67eeacf0ab6f56e10e05f25d382a29f6a381ac8ecf3814/lru_dict-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:98af7044b5c3d85a649e1afb8891829ff5210caf9143acc741b3e98ab1b66ff6", size = 11726, upload-time = "2025-11-02T10:00:57.377Z" }, - { url = "https://files.pythonhosted.org/packages/7a/23/bae4f32fb014fd2dc5512e9267a3b1ec34c3b55d16a2202a1193d9ae635d/lru_dict-1.4.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:906d99705b79a00b5668bdb8782ad823ccc8d26e1fc6b56327ae469a8d12e9b4", size = 29823, upload-time = "2025-11-02T10:00:58.34Z" }, - { url = "https://files.pythonhosted.org/packages/9f/3b/8c3d1e6a188ce65e0161b86dbd18f2290950baf1e9e28e4948fc123d9a67/lru_dict-1.4.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:885643fd968336d8652fddb0778184e2eeff7b7aebced6de268af6d6caef42d5", size = 30812, upload-time = "2025-11-02T10:00:59.358Z" }, - { url = "https://files.pythonhosted.org/packages/ed/11/7f061507eda944150ed959e99a3700ce6358c1241c7f697b2f1ade48646b/lru_dict-1.4.1-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:24c779334bed82f1a7eb2d1ebcba2b7aa9a1555d40a3b53e05eb6b9dfcb0609c", size = 32480, upload-time = "2025-11-02T10:01:00.141Z" }, - { url = "https://files.pythonhosted.org/packages/75/e7/94ac30d33c6f8a8eca5d7e81c0ce26fb7b79b18ea65accdcb2a652b19abc/lru_dict-1.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c6099e2ecb118dfeae4a197bfcc702ea5841bfd86f19d1b340e932d0f5c47c10", size = 30199, upload-time = "2025-11-02T10:01:01.31Z" }, - { url = "https://files.pythonhosted.org/packages/4a/81/c93ee7365db67dfb497e6218aa0395b9ec878c07c732d348bfbd651bcc95/lru_dict-1.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4e0db4f3105108598749550e639b283b07df0bb91cac3b47e86ffebcab721cc7", size = 31489, upload-time = "2025-11-02T10:01:02.363Z" }, - { url = "https://files.pythonhosted.org/packages/9f/0b/634e8b4eca2497647f802bbe1ae3f0e1e9a0de1d555cf77c022527b2682f/lru_dict-1.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e21f67ba374d1945051b547e719d44a8c7880718f67a15a03e7a12e1d12ea96b", size = 29522, upload-time = "2025-11-02T10:01:03.399Z" }, - { url = "https://files.pythonhosted.org/packages/de/cc/591b959d77cc0e0ac016f11baf26d03d566bb88a53fa9b41e157bc68bc4b/lru_dict-1.4.1-cp310-cp310-win32.whl", hash = "sha256:f309b4018dd41f33bf3bd4cc0f62421da8bcca513ea044dbb22f3cd029935012", size = 13066, upload-time = "2025-11-02T10:01:04.457Z" }, - { url = "https://files.pythonhosted.org/packages/d9/bc/c14b67fdbdb5a2a81cfb907ea8a8b0c9da5aed899f34921ebf097e22a966/lru_dict-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:e84cd1065955897de01f1fb4cbd6f87cab7706e920283bb98c27341d76dd9a8d", size = 14008, upload-time = "2025-11-02T10:01:05.421Z" }, - { url = "https://files.pythonhosted.org/packages/4c/ff/1d02bc444174f07d3ce747568989969c97dc77d0513f4c3b8b6224cb976f/lru_dict-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cc74c49cf1c26d6c28d8f6988cf0354696ca38a4f6012fa63055d2800791784b", size = 16760, upload-time = "2025-11-02T10:01:06.492Z" }, - { url = "https://files.pythonhosted.org/packages/0b/d8/e2e970272ea5fe7ba6349a5e7d0bb0fd814f5d1b88a53bc72b8c2a5e034f/lru_dict-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0158db85dfb2cd2fd2ddaa47709bdb073f814e0a8a149051b70b07e59ac83231", size = 11249, upload-time = "2025-11-02T10:01:07.261Z" }, - { url = "https://files.pythonhosted.org/packages/a5/26/860b5e60f339f8038118028388926224c8b70779e8243d68772e0e0d0ab3/lru_dict-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c8ac5cfd56e036bd8d7199626147044485fa64a163a5bde96bfa5a1c7fea2273", size = 11728, upload-time = "2025-11-02T10:01:08.185Z" }, - { url = "https://files.pythonhosted.org/packages/61/55/fc8f71953fd343ede33810b0a000b4130e03635ae09b28569e45735ded2f/lru_dict-1.4.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2eb2058cb7b329b4b72baee4cd1bb322af1feec73de79e68edb35d333c90b698", size = 30795, upload-time = "2025-11-02T10:01:08.862Z" }, - { url = "https://files.pythonhosted.org/packages/4c/26/ad549550e6a236818a91434570d38d7a93824b0410d3db1c845a53238e1f/lru_dict-1.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ffbb6f3c1e906e92d9129c14a88d81358be1e0b60195c1729b215a52e9670de", size = 31807, upload-time = "2025-11-02T10:01:09.581Z" }, - { url = "https://files.pythonhosted.org/packages/7c/39/72dae9ac0e95a8576a45e3bd62a6fc3e7dbb116794efa1337c7b450d4836/lru_dict-1.4.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:11b289d78a48a086846e46d2275707d33523f5d543475336c29c56fd5d0e65dc", size = 33437, upload-time = "2025-11-02T10:01:10.676Z" }, - { url = "https://files.pythonhosted.org/packages/a8/46/221479834703a5397fa32f07212ace38f104a31ad1af8a921cf25e053677/lru_dict-1.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3fe10c1f45712e191eecb2a69604d566c64ddfe01136fd467c890ed558c3ad40", size = 31168, upload-time = "2025-11-02T10:01:11.47Z" }, - { url = "https://files.pythonhosted.org/packages/6e/13/98d36e2522fda7f6625c15332562f81f1465161a5ae021d9b3b408f8c427/lru_dict-1.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e04820e3473bd7f55440f24c946ca4335e392d5e3e0e1e948020e94cd1954372", size = 32454, upload-time = "2025-11-02T10:01:12.522Z" }, - { url = "https://files.pythonhosted.org/packages/49/18/345ff2a98d27cddae40c84cf0466fcc329f3965cd21322bb561a94e4d332/lru_dict-1.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:edc004c88911a8f9715e716116d2520c13db89afd6c37cc0f28042ba10635163", size = 30574, upload-time = "2025-11-02T10:01:13.293Z" }, - { url = "https://files.pythonhosted.org/packages/d7/92/dfea71402a7ca46332bcb854827ee68bbc9be205e2558c3a40293eca9782/lru_dict-1.4.1-cp311-cp311-win32.whl", hash = "sha256:b0b5360264b37676c405ea0a560744d7dcb2d47adff1e7837113c15fabcc7a71", size = 13031, upload-time = "2025-11-02T10:01:13.96Z" }, - { url = "https://files.pythonhosted.org/packages/3a/7b/4c7d566d77ec3ad9128f07407494c2aec57909f8dd59f0c9910bd4c05840/lru_dict-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:bb4b37daad9fe4e796c462f4876cf34e52564630902bdf59a271bc482b48a361", size = 14007, upload-time = "2025-11-02T10:01:14.857Z" }, - { url = "https://files.pythonhosted.org/packages/4f/a8/89e4c26e0e751321b41b0a3007384f97d9eae7a863c49af1c68c43005ca3/lru_dict-1.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:7fa342c6e6bc811ee6a17eb569d37b149340d5aa5a637a53438e316a95783838", size = 16683, upload-time = "2025-11-02T10:01:15.891Z" }, - { url = "https://files.pythonhosted.org/packages/f1/34/b3c6fdd120af68b6eeb524d0de3293ff27918ec57f45eed6bef1789fd085/lru_dict-1.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bd86bd202a7c1585d9dc7e5b0c3d52cf76dc56b261b4bbecfeefbbae31a5c97d", size = 11216, upload-time = "2025-11-02T10:01:16.867Z" }, - { url = "https://files.pythonhosted.org/packages/e9/7e/280267ae23f1ec1074ddaab787c5e041e090220e8e37828d51ff4e681dfd/lru_dict-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4617554f3e42a8f520c8494842c23b98f5b7f4d5e0410e91a4c3ad0ea5f7e094", size = 11687, upload-time = "2025-11-02T10:01:17.485Z" }, - { url = "https://files.pythonhosted.org/packages/ca/18/fec42416ceff98ae2760067ec72b0b9fc02840e729bbc18059c6a02cb01f/lru_dict-1.4.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:40927a6a4284d437047f547e652b15f6f0f40210deb6b9e5b77e556ff0faea0f", size = 31960, upload-time = "2025-11-02T10:01:18.158Z" }, - { url = "https://files.pythonhosted.org/packages/c2/ef/38e7ee1a5d32b9b1629d045fa5a495375383aacfb2945f4d9535b9af9630/lru_dict-1.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2c07ecb6d42494e45d00c2541e6b0ae7659fc3cf89681521ba94b15c682d4fe", size = 32882, upload-time = "2025-11-02T10:01:18.841Z" }, - { url = "https://files.pythonhosted.org/packages/72/82/d56653ca144c291ab37bea5f23c5078ffbe64f7f5b466f91d400590b9106/lru_dict-1.4.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:85b28aa2de7c5f1f6c68221857accd084438df98edbd4f57595795734225770c", size = 34268, upload-time = "2025-11-02T10:01:19.564Z" }, - { url = "https://files.pythonhosted.org/packages/94/ae/382651533d60f0b598757efda56dc87cad5ac311fba8e61f86fb916bf236/lru_dict-1.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:cbbbb4b51e2529ccf7ee8a3c3b834052dbd54871a216cfd229dd2b1194ff293a", size = 32156, upload-time = "2025-11-02T10:01:20.22Z" }, - { url = "https://files.pythonhosted.org/packages/aa/d1/d9df7e9272ccbc96f04c477dfb9abb91fa8fabde86b7fa190cb7b3c7a024/lru_dict-1.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:e47040421a13de8bc6404557b3700c33f1f2683cbcce22fe5cacec4c938ce54b", size = 33395, upload-time = "2025-11-02T10:01:20.901Z" }, - { url = "https://files.pythonhosted.org/packages/e9/6e/dafe0f5943a7b3ab24d3429032ff85873acd626087934b8161b55340c13a/lru_dict-1.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:451f7249866cb9564bb40d73bec7ac865574dafd0a4cc91627bbf35be7e99291", size = 31591, upload-time = "2025-11-02T10:01:21.606Z" }, - { url = "https://files.pythonhosted.org/packages/a6/4d/9dd35444592bfb6805548e15971cfce821400966a51130b78dc021ee8f03/lru_dict-1.4.1-cp312-cp312-win32.whl", hash = "sha256:e8996f3f94870ecb236c55d280839390edae7f201858fee770267eac27b8b47d", size = 13119, upload-time = "2025-11-02T10:01:22.61Z" }, - { url = "https://files.pythonhosted.org/packages/8d/82/7e72e30d6c15d65466b3baca87cce15e20848ba6a488868aa54e901141a6/lru_dict-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:d90774db1b60c0d5c829cfa5d7fda6db96ed1519296f626575598f9f170cca37", size = 14109, upload-time = "2025-11-02T10:01:23.322Z" }, - { url = "https://files.pythonhosted.org/packages/ec/de/18ac3957e1aa6674a0a828748c819265f79b524ff30cbb0ac7f08ab786c8/lru_dict-1.4.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:cc9dd191870555624bbf3903c8afa3f01815ca3256ed8b35cb323f0db3ce4f98", size = 10467, upload-time = "2025-11-02T10:02:05.717Z" }, - { url = "https://files.pythonhosted.org/packages/0c/53/2a0bedaa64950cc56ade72e2f5a292318473585d9a3adc797d13b38082e7/lru_dict-1.4.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:afdf92b332632aa6e4b8646e93723f50f41fece2a80a54d2b44e8ac67f913ceb", size = 10871, upload-time = "2025-11-02T10:02:06.353Z" }, - { url = "https://files.pythonhosted.org/packages/4e/e2/d5ea49d62ea142559fd9cafd8505d4a4f87a1d81953a9c99fa61e7ccbd6b/lru_dict-1.4.1-pp310-pypy310_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3d6770adafae25663b682420891a10a5894595f02b1e4d87766f7adc8e56e72a", size = 12969, upload-time = "2025-11-02T10:02:07.196Z" }, - { url = "https://files.pythonhosted.org/packages/a2/67/0672caac9a04dc9011f7a27fc2ec2003f0bfa008070b29940d05b4dae56a/lru_dict-1.4.1-pp310-pypy310_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:018cd3b41224ca81eb83cdf6db024409a920e5c1d3ce4e8b323cb66e24a73132", size = 13959, upload-time = "2025-11-02T10:02:08.267Z" }, - { url = "https://files.pythonhosted.org/packages/e3/7e/313385214a5011cf9fe8376928f66f70bfedc48d8f7ab424292224ed4907/lru_dict-1.4.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:781dbcf0c83160e525482a4ebcd7c5065851a6c7295f1cda78248a2029f23f39", size = 14084, upload-time = "2025-11-02T10:02:08.993Z" }, - { url = "https://files.pythonhosted.org/packages/8e/47/08c61cad038706b3a89b8c7587ec74ed9731c1e536329745cccb6c840916/lru_dict-1.4.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9219f13e4101c064f70e1815d7c51f9be9e053983e74dfb7bcfdf92f5fcbb0e0", size = 10384, upload-time = "2025-11-02T10:02:09.656Z" }, - { url = "https://files.pythonhosted.org/packages/6b/a1/022c4d7c68c076370231488c97cf7451131fb9ca0d60d1b2785e7baa1f5b/lru_dict-1.4.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7e1ac7fb6e91e4d3212e153f9e2d98d163a4439b9bf9df247c22519262c26fe", size = 10822, upload-time = "2025-11-02T10:02:10.609Z" }, - { url = "https://files.pythonhosted.org/packages/65/b4/4c0a0877a77fececa9f58d804569e2aac1bfbe588e3a70e79647b5d8f7d4/lru_dict-1.4.1-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:23424321b761c43f3021a596565f8205ecec0e175822e7a5d9b2a175578aa7de", size = 12968, upload-time = "2025-11-02T10:02:11.405Z" }, - { url = "https://files.pythonhosted.org/packages/22/06/d7e393d07dc31e656330d5a058f34e972bf590e7dc882922b426f3aec4a0/lru_dict-1.4.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:804ee76f98afc3d50e9a2e9c835a6820877aa6391f2add520a57f86b3f55ec3a", size = 13904, upload-time = "2025-11-02T10:02:12.144Z" }, - { url = "https://files.pythonhosted.org/packages/e8/1e/0eee8bcc16bf01b265ac83e4b870596e2f3bcc40d88aa7ec25407180fe44/lru_dict-1.4.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:3be24e24c8998302ea1c28f997505fa6843f507aad3c7d5c3a82cc01c5c11be4", size = 14062, upload-time = "2025-11-02T10:02:12.878Z" }, -] - [[package]] name = "mako" version = "1.3.10" @@ -1211,19 +1065,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] -[[package]] -name = "mip" -version = "1.17.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cbcbox" }, - { name = "cffi" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d3/69/8b7695d78b96e997a691814e992732d98fa0f92c5c2a2885ec607f759aba/mip-1.17.4.tar.gz", hash = "sha256:0e7ca54424614bb9670795cc22cb7f700baf5a12c59bbc25af10b723bf0b64eb", size = 9443521, upload-time = "2026-03-12T16:44:59.218Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2b/1e/b02b0c57c2304944c371be03eb0b4293f66c81048e09fc5ce11042225b2b/mip-1.17.4-py3-none-any.whl", hash = "sha256:3b140b2954a50595ad4f9f263027087add3d22129ea29382d2ea2bacf79b485d", size = 88148, upload-time = "2026-03-12T16:44:57.755Z" }, -] - [[package]] name = "ml-dtypes" version = "0.5.4" @@ -1725,16 +1566,10 @@ all = [ { name = "datasets" }, { name = "deepspeed", marker = "sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "diffusers" }, - { name = "fire" }, { name = "huggingface-hub" }, - { name = "hydra-core" }, - { name = "immutabledict" }, { name = "lief" }, - { name = "lru-dict" }, - { name = "mip" }, { name = "ml-dtypes" }, { name = "nltk" }, - { name = "omegaconf" }, { name = "onnx" }, { name = "onnx-graphsurgeon" }, { name = "onnxconverter-common" }, @@ -1744,13 +1579,10 @@ all = [ { name = "onnxruntime-gpu", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "onnxscript" }, { name = "onnxslim" }, - { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "peft" }, { name = "polygraphy" }, { name = "sentencepiece" }, { name = "transformers" }, - { name = "typeguard" }, { name = "wonderwords" }, ] dev = [ @@ -1762,17 +1594,11 @@ dev = [ { name = "datasets" }, { name = "deepspeed", marker = "sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "diffusers" }, - { name = "fire" }, { name = "huggingface-hub" }, - { name = "hydra-core" }, - { name = "immutabledict" }, { name = "lief" }, - { name = "lru-dict" }, - { name = "mip" }, { name = "ml-dtypes" }, { name = "mypy" }, { name = "nltk" }, - { name = "omegaconf" }, { name = "onnx" }, { name = "onnx-graphsurgeon" }, { name = "onnxconverter-common" }, @@ -1782,8 +1608,6 @@ dev = [ { name = "onnxruntime-gpu", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "onnxscript" }, { name = "onnxslim" }, - { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "peft" }, { name = "polygraphy" }, { name = "pre-commit" }, @@ -1808,7 +1632,6 @@ dev = [ { name = "tox" }, { name = "tox-current-env" }, { name = "transformers" }, - { name = "typeguard" }, { name = "wonderwords" }, ] dev-docs = [ @@ -1868,17 +1691,6 @@ onnx = [ { name = "onnxslim" }, { name = "polygraphy" }, ] -puzzletron = [ - { name = "fire" }, - { name = "hydra-core" }, - { name = "immutabledict" }, - { name = "lru-dict" }, - { name = "mip" }, - { name = "omegaconf" }, - { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "typeguard" }, -] [package.metadata] requires-dist = [ @@ -1890,13 +1702,8 @@ requires-dist = [ { name = "datasets", marker = "extra == 'hf'", specifier = ">=3.0.0" }, { name = "deepspeed", marker = "sys_platform != 'darwin' and sys_platform != 'win32' and extra == 'hf'", specifier = ">=0.9.6" }, { name = "diffusers", marker = "extra == 'hf'", specifier = ">=0.32.2" }, - { name = "fire", marker = "extra == 'puzzletron'" }, { name = "huggingface-hub", marker = "extra == 'hf'", specifier = ">=0.24.0" }, - { name = "hydra-core", marker = "extra == 'puzzletron'", specifier = "==1.3.2" }, - { name = "immutabledict", marker = "extra == 'puzzletron'" }, { name = "lief", marker = "extra == 'onnx'" }, - { name = "lru-dict", marker = "extra == 'puzzletron'" }, - { name = "mip", marker = "extra == 'puzzletron'" }, { name = "ml-dtypes", marker = "extra == 'onnx'" }, { name = "mypy", marker = "extra == 'dev-lint'", specifier = "==1.17.1" }, { name = "ninja" }, @@ -1904,8 +1711,8 @@ requires-dist = [ { name = "numpy" }, { name = "nvidia-ml-py", specifier = ">=12" }, { name = "nvidia-modelopt", extras = ["all", "dev-docs", "dev-lint", "dev-test"], marker = "extra == 'dev'" }, - { name = "nvidia-modelopt", extras = ["hf", "onnx", "puzzletron"], marker = "extra == 'all'" }, - { name = "omegaconf", marker = "extra == 'puzzletron'", specifier = "==2.3.0" }, + { name = "nvidia-modelopt", extras = ["hf", "onnx"], marker = "extra == 'all'" }, + { name = "omegaconf", specifier = ">=2.3.0" }, { name = "onnx", marker = "extra == 'onnx'", specifier = "~=1.19.0" }, { name = "onnx-graphsurgeon", marker = "extra == 'onnx'" }, { name = "onnxconverter-common", marker = "extra == 'onnx'", specifier = "~=1.16.0" }, @@ -1917,7 +1724,6 @@ requires-dist = [ { name = "onnxscript", marker = "extra == 'onnx'" }, { name = "onnxslim", marker = "extra == 'onnx'", specifier = ">=0.1.76" }, { name = "packaging" }, - { name = "pandas", marker = "extra == 'puzzletron'" }, { name = "peft", marker = "extra == 'hf'", specifier = ">=0.17.0" }, { name = "polygraphy", marker = "extra == 'onnx'", specifier = ">=0.49.22" }, { name = "pre-commit", marker = "extra == 'dev-lint'", specifier = "==4.3.0" }, @@ -1951,23 +1757,9 @@ requires-dist = [ { name = "tox-current-env", marker = "extra == 'dev-test'", specifier = ">=0.0.12" }, { name = "tqdm" }, { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53,<5.0" }, - { name = "typeguard", marker = "extra == 'puzzletron'" }, { name = "wonderwords", marker = "extra == 'hf'" }, ] -provides-extras = ["onnx", "hf", "puzzletron", "dev-lint", "dev-docs", "dev-test", "all", "dev"] - -[[package]] -name = "omegaconf" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "antlr4-python3-runtime" }, - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, -] +provides-extras = ["onnx", "hf", "dev-lint", "dev-docs", "dev-test", "all", "dev"] [[package]] name = "omegaconf" @@ -2680,15 +2472,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/c5/e98d9c51f3d5300d5e40ad9037dd6b3b60736fd02ab68dcc98c96be7592d/pybind11-3.0.2-py3-none-any.whl", hash = "sha256:f8a6500548919cc33bcd220d5f984688326f574fa97f1107f2f4fdb4c6fb019f", size = 310158, upload-time = "2026-02-17T04:46:49.91Z" }, ] -[[package]] -name = "pycparser" -version = "3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1b/7d/92392ff7815c21062bea51aa7b87d45576f649f16458d78b7cf94b9ab2e6/pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29", size = 103492, upload-time = "2026-01-21T14:26:51.89Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, -] - [[package]] name = "pydantic" version = "2.12.5" @@ -3612,15 +3395,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] -[[package]] -name = "termcolor" -version = "3.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/46/79/cf31d7a93a8fdc6aa0fbb665be84426a8c5a557d9240b6239e9e11e35fc5/termcolor-3.3.0.tar.gz", hash = "sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5", size = 14434, upload-time = "2025-12-29T12:55:21.882Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/d1/8bb87d21e9aeb323cc03034f5eaf2c8f69841e40e4853c2627edf8111ed3/termcolor-3.3.0-py3-none-any.whl", hash = "sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5", size = 7734, upload-time = "2025-12-29T12:55:20.718Z" }, -] - [[package]] name = "timm" version = "1.0.25" @@ -3864,18 +3638,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" }, ] -[[package]] -name = "typeguard" -version = "4.5.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/2b/e8/66e25efcc18542d58706ce4e50415710593721aae26e794ab1dec34fb66f/typeguard-4.5.1.tar.gz", hash = "sha256:f6f8ecbbc819c9bc749983cc67c02391e16a9b43b8b27f15dc70ed7c4a007274", size = 80121, upload-time = "2026-02-19T16:09:03.392Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/91/88/b55b3117287a8540b76dbdd87733808d4d01c8067a3b339408c250bb3600/typeguard-4.5.1-py3-none-any.whl", hash = "sha256:44d2bf329d49a244110a090b55f5f91aa82d9a9834ebfd30bcc73651e4a8cc40", size = 36745, upload-time = "2026-02-19T16:09:01.6Z" }, -] - [[package]] name = "typing-extensions" version = "4.15.0" From d0209dc606b6ba6ddd63560f9a510191f70ef95c Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 25 Mar 2026 12:30:45 -0700 Subject: [PATCH 54/62] Make Qwen and QwenVL descriptor generic so can be used for other variants Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../pruning/ffn_pruning.yaml | 4 +-- .../puzzletron/anymodel/models/__init__.py | 4 +-- .../__init__.py | 8 ++---- .../qwen3_converter.py} | 2 +- .../qwen3_model_descriptor.py} | 6 ++-- .../models/{qwen3_8b => qwen3_vl}/__init__.py | 6 ++-- .../qwen3_vl_converter.py} | 2 +- .../qwen3_vl_model_descriptor.py} | 11 ++++---- .../mbridge_distillation/test_distill_hf.py | 28 +++++++------------ .../Qwen/Qwen3-8B/pruning/ffn_pruning.yaml | 2 +- .../pruning/expert_pruning.yaml | 4 +-- 11 files changed, 33 insertions(+), 44 deletions(-) rename modelopt/torch/puzzletron/anymodel/models/{qwen3_vl_30b_a3b_instruct => qwen3}/__init__.py (67%) rename modelopt/torch/puzzletron/anymodel/models/{qwen3_8b/qwen3_8b_converter.py => qwen3/qwen3_converter.py} (97%) rename modelopt/torch/puzzletron/anymodel/models/{qwen3_8b/qwen3_8b_model_descriptor.py => qwen3/qwen3_model_descriptor.py} (96%) rename modelopt/torch/puzzletron/anymodel/models/{qwen3_8b => qwen3_vl}/__init__.py (78%) rename modelopt/torch/puzzletron/anymodel/models/{qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py => qwen3_vl/qwen3_vl_converter.py} (98%) rename modelopt/torch/puzzletron/anymodel/models/{qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py => qwen3_vl/qwen3_vl_model_descriptor.py} (95%) diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml index 70dd5fd006..93590d13e5 100644 --- a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml @@ -6,7 +6,7 @@ activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activati pruning_mixin: _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn layer_descriptor: - _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor.Qwen3_8BFFNIntermediateLayerDescriptor + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} activation_hooks_kwargs: @@ -14,5 +14,5 @@ activation_hooks_kwargs: target_layer: "mlp.down_proj" layer_input_descriptors_path: -intermediate_size_list: [256] # teacher_intermediate_size is 14336 +intermediate_size_list: [256] # teacher_intermediate_size is 14336 mlp_init_mode: "PruneByActivationsLog" diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py index 34d7ce5e5a..4c68dbc823 100644 --- a/modelopt/torch/puzzletron/anymodel/models/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -20,5 +20,5 @@ from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * from modelopt.torch.puzzletron.anymodel.models.qwen2 import * -from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import * -from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import * +from modelopt.torch.puzzletron.anymodel.models.qwen3 import * +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl import * diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py similarity index 67% rename from modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/__init__.py rename to modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py index 7bf317d29e..cf28475718 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_converter import ( - Qwen3VL30BA3BInstructConverter, -) -from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor import ( - Qwen3VL30BA3BInstructModelDescriptor, +from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_converter import Qwen3Converter +from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor import ( + Qwen3ModelDescriptor, ) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py similarity index 97% rename from modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_converter.py rename to modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py index 1a389291df..830c7ba960 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_converter.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py @@ -28,7 +28,7 @@ @ConverterFactory.register_decorator("qwen3") -class Qwen3_8BConverter(Converter): +class Qwen3Converter(Converter): @staticmethod def create_block_configs_from_main_config(config: Qwen3Config) -> List[BlockConfig]: num_hidden_layers = config.num_hidden_layers diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py similarity index 96% rename from modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_model_descriptor.py rename to modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py index 679ee73fae..ae70d96617 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/qwen3_8b_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py @@ -44,7 +44,7 @@ @ModelDescriptorFactory.register_decorator("qwen3") -class Qwen3_8BModelDescriptor(ModelDescriptor): +class Qwen3ModelDescriptor(ModelDescriptor): @staticmethod def decoder_layer_cls(): return Qwen3DecoderLayer @@ -135,7 +135,7 @@ def build_attention_predicates() -> Dict[str, re.Pattern]: @dataclass -class Qwen3_8BFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): +class Qwen3FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): down_proj_name: str = "mlp.down_proj" ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" linear_weight_names: List[str] = field( @@ -144,7 +144,7 @@ class Qwen3_8BFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): @dataclass -class Qwen3_8BKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): +class Qwen3KVHeadsLayerDescriptor(KVHeadsLayerDescriptor): o_proj_name: str = "self_attn.o_proj" attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" qkvo_weight_names: List[str] = field( diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py similarity index 78% rename from modelopt/torch/puzzletron/anymodel/models/qwen3_8b/__init__.py rename to modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py index 0f753f705d..48dbd2de8a 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_8b/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_converter import Qwen3_8BConverter -from modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor import ( - Qwen3_8BModelDescriptor, +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_converter import Qwen3VLConverter +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor import ( + Qwen3VLModelDescriptor, ) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py similarity index 98% rename from modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py rename to modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py index 0c50dfeb9e..82e51b7b80 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_converter.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py @@ -29,7 +29,7 @@ @ConverterFactory.register_decorator("qwen3_vl") -class Qwen3VL30BA3BInstructConverter(Converter): +class Qwen3VLConverter(Converter): @staticmethod def create_block_configs_from_main_config(config: Qwen3VLMoeConfig) -> List[BlockConfig]: # Qwen3-VL MoE has nested text_config diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py similarity index 95% rename from modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py rename to modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py index 7c7665a644..7a1641969a 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl_30b_a3b_instruct/qwen3_vl_30b_a3b_instruct_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py @@ -19,7 +19,6 @@ from dataclasses import dataclass, field from typing import Dict, List -import torch.nn as nn from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( Qwen3VLMoeTextDecoderLayer, Qwen3VLMoeTextRotaryEmbedding, @@ -46,7 +45,7 @@ @ModelDescriptorFactory.register_decorator("qwen3_vl") -class Qwen3VL30BA3BInstructModelDescriptor(ModelDescriptor): +class Qwen3VLModelDescriptor(ModelDescriptor): @staticmethod def uses_autocast() -> bool: """ @@ -90,7 +89,7 @@ def mlp_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer): @staticmethod def init_rotary_embedding(model, runtime): # Re-initialize text rotary embedding on correct device and dtype - text_config = Qwen3VL30BA3BInstructModelDescriptor.get_language_model_config(model.config) + text_config = Qwen3VLModelDescriptor.get_language_model_config(model.config) model.model.language_model.rotary_emb = Qwen3VLMoeTextRotaryEmbedding( config=text_config ).to(device=runtime.device, dtype=runtime.dtype) @@ -171,7 +170,7 @@ def build_attention_predicates() -> Dict[str, re.Pattern]: @dataclass -class Qwen3VL30BA3BInstructFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): +class Qwen3VLFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): down_proj_name: str = "mlp.down_proj" ffn_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp" linear_weight_names: List[str] = field( @@ -180,7 +179,7 @@ class Qwen3VL30BA3BInstructFFNIntermediateLayerDescriptor(FFNIntermediateLayerDe @dataclass -class Qwen3VL30BA3BInstructKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): +class Qwen3VLKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): o_proj_name: str = "self_attn.o_proj" attn_prefix_name: str = "model.language_model.layers.{layer_idx}.self_attn" qkvo_weight_names: List[str] = field( @@ -189,7 +188,7 @@ class Qwen3VL30BA3BInstructKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): @dataclass -class Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): +class Qwen3VLExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): """ Qwen3-VL MoE layer descriptor. diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py index 7b0c9a32f6..1bd0037510 100644 --- a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -28,7 +28,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): """Integration test for distill_hf.py. - Creates Llama models programmatically, converts them to heterogeneous format (AnyModel), + Creates Qwen3 models programmatically, converts them to heterogeneous format (AnyModel), and runs mbridge distillation. The models are created with reduced size for faster testing. Models are converted to include block_configs. """ @@ -59,7 +59,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): cmd_parts, student_hf_path=student_hf_path, teacher_hf_path=teacher_hf_path, - output_dir=str(output_dir), + output_dir=output_dir, tp_size=tp_size, pp_size=1, seq_length=128, @@ -73,8 +73,8 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): eval_interval=100, eval_iters=0, log_interval=5, - hf_export_path=str(hf_export_dir), - hf_model="meta-llama/Llama-3.1-8B-Instruct", + hf_export_path=hf_export_dir, + hf_model="Qwen/Qwen3-0.6B", ) run_example_command(cmd_parts, example_path="puzzletron/mbridge_distillation") @@ -84,11 +84,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): assert run_config_path.exists(), f"Expected run_config.yaml to exist at: {run_config_path}" # Verify that the distilled model can be loaded in HuggingFace format - model = AutoModelForCausalLM.from_pretrained( - str(hf_export_dir), - local_files_only=True, - trust_remote_code=True, - ) + model = AutoModelForCausalLM.from_pretrained(str(hf_export_dir)) assert model is not None, "Failed to load distilled model with AutoModelForCausalLM" print( @@ -100,7 +96,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) -> tuple[str, str]: """Prepare student and teacher models for distillation. - Creates Llama models programmatically, converts them to heterogeneous format (AnyModel), + Creates Qwen3 models programmatically, converts them to heterogeneous format (AnyModel), and returns the paths to the converted checkpoints. Args: @@ -124,7 +120,7 @@ def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) output_path=str(student_hf_dir), vocab_size=tokenizer.vocab_size, tokenizer=tokenizer, - hf_model_name="meta-llama/Llama-3.1-8B-Instruct", + hf_model_name="Qwen/Qwen3-0.6B", hybrid_override_pattern=None, ) @@ -133,7 +129,7 @@ def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) output_path=str(teacher_hf_dir), vocab_size=tokenizer.vocab_size, tokenizer=tokenizer, - hf_model_name="meta-llama/Llama-3.1-8B-Instruct", + hf_model_name="Qwen/Qwen3-0.6B", hybrid_override_pattern=None, ) @@ -143,15 +139,11 @@ def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) teacher_anymodel_dir = tmp_path / "teacher_anymodel" convert_model( - input_dir=str(student_hf_dir), - output_dir=str(student_anymodel_dir), - converter="llama", + input_dir=str(student_hf_dir), output_dir=str(student_anymodel_dir), converter="qwen3" ) convert_model( - input_dir=str(teacher_hf_dir), - output_dir=str(teacher_anymodel_dir), - converter="llama", + input_dir=str(teacher_hf_dir), output_dir=str(teacher_anymodel_dir), converter="qwen3" ) print("Models converted to AnyModel format:") print(f" Student AnyModel: {student_anymodel_dir}") diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml index e6e6ce5bb4..6bfeec715c 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml @@ -4,4 +4,4 @@ defaults: pruning_mixin: layer_descriptor: - _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor.Qwen3_8BFFNIntermediateLayerDescriptor + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml index bc1124617e..4e0786dc7a 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml @@ -6,14 +6,14 @@ activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruni pruning_mixin: _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn layer_descriptor: - _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor.Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor.Qwen3VLExpertRemovalLayerDescriptor target_name: "mlp" hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.Qwen3VLRemoveExpertsIndependentHook} activation_hooks_kwargs: # num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist) -num_experts_to_keep_list: [8] # num_experts in test model is 16, num_experts_per_tok is 8 +num_experts_to_keep_list: [8] # num_experts in test model is 16, num_experts_per_tok is 8 mlp_init_mode: "ExpertRemoval" mlp_init_config_yaml: expert_scores_key: "expert_ranks_mse" From d987bad2a4e84c233b7fb764d937c50e2ee45ba0 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 30 Mar 2026 08:07:26 -0700 Subject: [PATCH 55/62] Set strict=True in distill_hf export Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../export/mbridge/export_mbridge_to_hf.py | 10 +-------- .../mbridge_distillation/test_distill_hf.py | 22 ++++++++----------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py index 59e1d7dade..0ab6083f77 100644 --- a/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py +++ b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py @@ -63,16 +63,8 @@ def export_to_hf_and_copy_config( bridge = AutoBridge.from_hf_pretrained(hf_model, trust_remote_code=trust_remote_code) print_rank_0("📤 Exporting to HuggingFace format...") - # Use strict=False for test_distill_hf.py which uses small models (2 layers) with fewer layers - # than the template model (32 layers). This allows partial exports when some tensors are missing. - # Note: This is NOT needed when running on real compressed puzzletron student models, - # which have the same number of layers as the template model (some may be skipped via no_op - # in block_configs, but all layer tensors are still present in the checkpoint). bridge.export_ckpt( - megatron_path=megatron_path, - hf_path=hf_export_path, - show_progress=True, - strict=False, # Needed for test_distill_hf.py small models; not needed for real compressed models + megatron_path=megatron_path, hf_path=hf_export_path, show_progress=True, strict=True ) print_rank_0(f"✅ Successfully exported model to: {hf_export_path}") diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py index 1bd0037510..cb649d6c1c 100644 --- a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -33,8 +33,8 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): Models are converted to include block_configs. """ # Prepare student and teacher models - student_hf_path, teacher_hf_path = _prepare_student_and_teacher_models( - project_root_path, tmp_path + student_hf_dir, student_anymodel_dir, _, teacher_anymodel_dir = ( + _prepare_student_and_teacher_models(project_root_path, tmp_path) ) output_dir = tmp_path / "distill_output" @@ -57,8 +57,8 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): ] extend_cmd_parts( cmd_parts, - student_hf_path=student_hf_path, - teacher_hf_path=teacher_hf_path, + student_hf_path=student_anymodel_dir, + teacher_hf_path=teacher_anymodel_dir, output_dir=output_dir, tp_size=tp_size, pp_size=1, @@ -74,7 +74,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): eval_iters=0, log_interval=5, hf_export_path=hf_export_dir, - hf_model="Qwen/Qwen3-0.6B", + hf_model=student_hf_dir, ) run_example_command(cmd_parts, example_path="puzzletron/mbridge_distillation") @@ -93,18 +93,14 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): ) -def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) -> tuple[str, str]: +def _prepare_student_and_teacher_models( + project_root_path: Path, tmp_path: Path +) -> tuple[Path, Path, Path, Path]: """Prepare student and teacher models for distillation. Creates Qwen3 models programmatically, converts them to heterogeneous format (AnyModel), and returns the paths to the converted checkpoints. - Args: - project_root_path: Path to the project root directory - tmp_path: Temporary directory for test artifacts - - Returns: - Tuple of (student_hf_path, teacher_hf_path) as strings """ # Create temporary directories for models @@ -149,4 +145,4 @@ def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) print(f" Student AnyModel: {student_anymodel_dir}") print(f" Teacher AnyModel: {teacher_anymodel_dir}") - return student_anymodel_dir, teacher_anymodel_dir + return student_hf_dir, student_anymodel_dir, teacher_hf_dir, teacher_anymodel_dir From 75651cc157596c2185d6556b3a93f635e7f5584a Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:30:45 -0700 Subject: [PATCH 56/62] add basic ruff fixes Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .pre-commit-config.yaml | 3 +-- .../nemotron_h/nemotron_h_model_descriptor.py | 1 - .../nemotron_h_v2_model_descriptor.py | 1 - .../models/qwen2/qwen2_model_descriptor.py | 2 -- .../models/qwen3_vl/qwen3_vl_model_descriptor.py | 2 +- modelopt/torch/puzzletron/mip/run_puzzle.py | 7 ++++--- modelopt/torch/puzzletron/mip/utils.py | 2 -- modelopt/torch/puzzletron/sewing_kit/core.py | 2 -- modelopt/torch/puzzletron/sewing_kit/utils.py | 15 +++++++++------ .../subblock_stats/calc_subblock_stats.py | 1 - .../bypassed_training/init_child_from_parent.py | 3 +-- pyproject.toml | 9 --------- .../mbridge_distillation/test_distill_hf.py | 2 +- 13 files changed, 17 insertions(+), 33 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cd7f922fbb..7810db7886 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,6 +79,7 @@ repos: modelopt/onnx/quantization/ort_patching.py| modelopt/torch/_deploy/utils/onnx_utils.py| modelopt/torch/export/transformer_engine.py| + modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py| modelopt/torch/quantization/export_onnx.py| modelopt/torch/quantization/plugins/attention.py| modelopt/torch/speculative/eagle/utils.py| @@ -100,8 +101,6 @@ repos: examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| examples/speculative_decoding/server_generate.py| - examples/puzzletron/evaluation/lm_eval_anymodel.py| - modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py| experimental/dms/models/qwen3/configuration_qwen3_dms.py| experimental/dms/models/qwen3/modeling_qwen3_dms.py| )$ diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 55d9ef56ca..7687d57c83 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -151,7 +151,6 @@ def init_rotary_embedding(model, runtime): """ NemotronH has no positional embeddings """ - pass @staticmethod def input_embedding_name(): diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py index f50217d4d3..c8c89658bf 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -131,7 +131,6 @@ def init_rotary_embedding(model, runtime): """ NemotronH has no positional embeddings """ - pass @staticmethod def input_embedding_name(): diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py index 69185d1de3..c2bbeed7a9 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py @@ -144,5 +144,3 @@ class Qwen2FFNIntermediateLayerDescriptor(LlamaFFNIntermediateLayerDescriptor): Qwen2 uses the same FFN structure as Llama (gate_proj, up_proj, down_proj). """ - - pass diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py index 7a1641969a..8c182c8968 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py @@ -202,7 +202,7 @@ class Qwen3VLExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): moe_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp" # Router: Qwen3VLMoeTextTopKRouter has self.weight, no bias router_weights: List[str] = field(default_factory=lambda: ["gate.weight"]) - router_biases: List[str] = field(default_factory=lambda: []) + router_biases: List[str] = field(default_factory=list) # Fused expert format: Qwen3VLMoeTextExperts stores all experts in single tensors # with shape [num_experts, ...] instead of separate tensors per expert. is_fused_experts: bool = True diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index 71913db7d3..27bd4bea12 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -20,6 +20,7 @@ import dataclasses import enum import json +import sys from collections.abc import Hashable, Iterable from copy import deepcopy from pathlib import Path @@ -401,14 +402,14 @@ def _assert_valid_config(args, puzzle_profile): missing_args = [arg for arg in required_args if arg not in args or getattr(args, arg) is None] if missing_args: mprint(f"error: The following arguments are required: {', '.join(missing_args)}") - exit(1) + sys.exit(1) # Make sure we have specified subblock_stats_args if "subblock_stats_args" not in args and "subblock_stats_args" not in puzzle_profile: mprint( "error: Must specify `subblock_stats_arrs` in either puzzle_profile or as a commandline arg." ) - exit(1) + sys.exit(1) # Make sure we have specified constraints if ( @@ -420,7 +421,7 @@ def _assert_valid_config(args, puzzle_profile): mprint( "error: Must specify either `mip_constraints` or `human_constraints` in one of puzzle_profile or as a commandline argument." ) - exit(1) + sys.exit(1) def _get_minimal_unique_names(dicts: list[dict]) -> list[str]: diff --git a/modelopt/torch/puzzletron/mip/utils.py b/modelopt/torch/puzzletron/mip/utils.py index 7398203cc2..b276ff33b1 100644 --- a/modelopt/torch/puzzletron/mip/utils.py +++ b/modelopt/torch/puzzletron/mip/utils.py @@ -21,8 +21,6 @@ class InfeasibleError(Exception): """Exception raised when optimization problem is infeasible.""" - pass - def sort_replacements(layer_replacements: list[dict]) -> list[dict]: """Sort layer replacements by parent layer indices. diff --git a/modelopt/torch/puzzletron/sewing_kit/core.py b/modelopt/torch/puzzletron/sewing_kit/core.py index 41eaeee75f..fb9055c3ed 100644 --- a/modelopt/torch/puzzletron/sewing_kit/core.py +++ b/modelopt/torch/puzzletron/sewing_kit/core.py @@ -676,8 +676,6 @@ def forward( if work is not None: work.wait() - pass - if len(node.stitches_from) > 0: assert len(peers) == 1, ( f"Cannot use multiple peers when using RemoteTarget as a source ({peers=})" diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 19c1bd6c83..068abef99e 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -16,9 +16,9 @@ from __future__ import annotations import inspect -from collections.abc import Sequence from contextlib import contextmanager from typing import ( + TYPE_CHECKING, Any, Callable, ContextManager, @@ -43,6 +43,9 @@ from torch._subclasses import FakeTensor, FakeTensorMode from typing_extensions import override +if TYPE_CHECKING: + from collections.abc import Sequence + Fn = TypeVar("Fn", bound=Callable) @@ -61,11 +64,11 @@ def __call__(self, fn: Fn, disable: bool = False) -> Fn: ... try: - dynamo_skip: DynamoSkip = cast(Any, torch._dynamo.decorators).skip - dynamo_disable: DynamoDisable = cast(Any, torch._dynamo.decorators).disable + dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.decorators).skip + dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.decorators).disable except: - dynamo_skip: DynamoSkip = cast(Any, torch._dynamo.eval_frame).skip - dynamo_disable: DynamoDisable = cast(Any, torch._dynamo.eval_frame).disable + dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.eval_frame).skip + dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.eval_frame).disable TModule = TypeVar("TModule", bound=nn.Module) @@ -264,7 +267,7 @@ def __new__(cls, elem, device) -> MyFakeTensor: dispatch_device=True, device_for_backend_keys=device, ) - return cast(MyFakeTensor, self) + return cast("MyFakeTensor", self) @classmethod @dynamo_disable diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index cb178e0566..80d5216987 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -203,7 +203,6 @@ def calculate_subblock_stats( ) if is_calc_runtime: - pass # TODO: fix # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index c4d2ea054e..5c6c25681d 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -70,8 +70,7 @@ def init_child_from_parent( - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) """ assert ( - gqa_init_mode != GQAInitMode.RandomKV - and gqa_init_mode != GQAInitMode.RandomBlock + gqa_init_mode not in [GQAInitMode.RandomKV, GQAInitMode.RandomBlock] and mlp_init_mode != MlpInitMode.Random and linear_init_mode != LinearInitMode.Random ), ( diff --git a/pyproject.toml b/pyproject.toml index d88ed5b807..e9d40f466a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,20 +218,11 @@ extend-ignore = [ "D", "E", "F", - "FURB", - "ISC", "N", "PERF", - "PGH", - "PIE", - "PLE", - "PLR", - "PT", "RUF", "SIM", - "TC", "UP", - "W", ] # TODO: Disabled for now, will enable later, once all puzzletron code is migrated "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style "modelopt/torch/sparsity/attention_sparsity/kernels/*" = [ diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py index cb649d6c1c..db886f5050 100644 --- a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -84,7 +84,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): assert run_config_path.exists(), f"Expected run_config.yaml to exist at: {run_config_path}" # Verify that the distilled model can be loaded in HuggingFace format - model = AutoModelForCausalLM.from_pretrained(str(hf_export_dir)) + model = AutoModelForCausalLM.from_pretrained(hf_export_dir) assert model is not None, "Failed to load distilled model with AutoModelForCausalLM" print( From 03118ce1de5b9165c0bb6d8b3d455029d5ad2583 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 30 Mar 2026 08:34:22 -0700 Subject: [PATCH 57/62] Apply coderabbit suggestions Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../anymodel/models/qwen3/qwen3_converter.py | 15 +++++++++------ modelopt/torch/puzzletron/mip/run_puzzle.py | 10 +++++----- .../subblock_stats/calc_subblock_stats.py | 4 +++- .../bypassed_training/init_child_from_parent.py | 6 +++++- .../torch/puzzletron/tools/checkpoint_utils.py | 5 ++--- .../torch/puzzletron/tools/checkpoint_utils_hf.py | 2 -- modelopt/torch/puzzletron/tools/logger.py | 2 +- .../validate_puzzle_with_multi_replacements.py | 6 +++++- 8 files changed, 30 insertions(+), 20 deletions(-) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py index 830c7ba960..bad9bb47d6 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py @@ -33,10 +33,13 @@ class Qwen3Converter(Converter): def create_block_configs_from_main_config(config: Qwen3Config) -> List[BlockConfig]: num_hidden_layers = config.num_hidden_layers - block_config = BlockConfig( - attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), - ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), - ).to_dict() - - block_configs = [block_config] * num_hidden_layers + block_configs = [ + BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + for _ in range(num_hidden_layers) + ] return block_configs diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index 27bd4bea12..dd50003d1b 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -156,7 +156,7 @@ def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]: return self.constraints assert all(key in subblock_stats_args for key in ("batch_size", "generation_seq_len")), ( - "Can't realize human constraints without 'block_size' and 'generation_seq_len' in subblock_stats_args." + "Can't realize human constraints without 'batch_size' and 'generation_seq_len' in subblock_stats_args." ) batch_size = subblock_stats_args["batch_size"] generation_seq_len = subblock_stats_args["generation_seq_len"] @@ -192,7 +192,7 @@ def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]: return mip_constraints -def parse_args() -> argparse.Namespace: +def parse_args() -> DictConfig: parser = argparse.ArgumentParser() parser.add_argument("--puzzle_profile", type=parse_path) @@ -228,11 +228,11 @@ def parse_args() -> argparse.Namespace: ) args = parser.parse_args() - return args + return DictConfig(vars(args)) def run_single_puzzle_config( - args: argparse.Namespace | DictConfig, + args: DictConfig, gathered_metrics: dict, subblock_stats: dict, subblock_stats_args: dict, @@ -432,7 +432,7 @@ def _get_minimal_unique_names(dicts: list[dict]) -> list[str]: return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts] -def run_puzzle(args: argparse.Namespace | DictConfig) -> list[str]: +def run_puzzle(args: DictConfig) -> list[str]: # Loads config from args/puzzle_profile if args.puzzle_profile is not None: with open(args.puzzle_profile) as f: diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index 80d5216987..549d994f07 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -19,6 +19,7 @@ import dataclasses import json import os +import warnings from functools import partial from itertools import product from pathlib import Path @@ -118,6 +119,7 @@ def calculate_subblock_stats( } # Compute runtime stats for unique subblocks only if is_calc_runtime: + raise NotImplementedError("Runtime stats calculation is not implemented yet") subblock_configs_nolayerindex = set( [subblock_config["subblock_config"] for subblock_config in subblock_configs] ) @@ -314,7 +316,7 @@ def calculate_subblock_stats_for_puzzle_dir( moe_stats_file = master_puzzle_dir / moe_stats_filename if not moe_stats_file.exists(): - Warning( + warnings.warn( f"MOE stats file {moe_stats_file} does not exist, can't calculate num active params" ) moe_stats_file = None diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index 5c6c25681d..783d233c3e 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -79,7 +79,11 @@ def init_child_from_parent( descriptor = ModelDescriptorFactory.get(descriptor) - copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir) + copy_tokenizer( + parent_checkpoint_dir, + output_checkpoint_dir, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) parent_model_config = load_model_config( parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code() diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py index 0ef4bfa472..8b14bd8027 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -35,8 +35,6 @@ PTH_SUBBLOCKS_DIR_NAME = "subblocks" STATE_DICT_FILE_NAME = "model.pth" -warnings.filterwarnings("ignore", "You are using `torch.load` with `weights_only=False`*.") - def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: checkpoint_dir = _normalize_checkpoint_dir(checkpoint_dir) @@ -162,6 +160,7 @@ def copy_tokenizer( source_dir_or_tokenizer_name: Path | str, target_dir: Path | str, on_failure: Literal["raise", "warn"] = "raise", + trust_remote_code: bool = False, ) -> None: """Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available) to avoid collision between transformers versions. @@ -173,7 +172,7 @@ def copy_tokenizer( tokenizer = None try: tokenizer = AutoTokenizer.from_pretrained( - source_dir_or_tokenizer_name, trust_remote_code=True + source_dir_or_tokenizer_name, trust_remote_code=trust_remote_code ) except Exception: message = f"Couldn't load tokenizer from '{source_dir_or_tokenizer_name}'" diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 38dfaaf00b..1c6dcb36d8 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -63,8 +63,6 @@ } LAYERS_MODULE_NAME = "model.layers" -warnings.filterwarnings("ignore", "You are using `torch.load` with `weights_only=False`*.") - def force_cache_dynamic_modules( config: PretrainedConfig, checkpoint_dir: Path | str, trust_remote_code: bool = False diff --git a/modelopt/torch/puzzletron/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py index e4b87e3770..02a69512ba 100644 --- a/modelopt/torch/puzzletron/tools/logger.py +++ b/modelopt/torch/puzzletron/tools/logger.py @@ -71,7 +71,7 @@ def dist_log(self, msg: str, ranks: str = "main"): # Only main rank at node 0 to print elif ( (ranks == "main" and self.global_rank != 0) - or (ranks == "last" and self.local_rank != self.world_size - 1) + or (ranks == "last" and self.global_rank != self.world_size - 1) or (ranks == "local_main" and self.local_rank != 0) ): return diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index 6bf966a2ae..a07590452f 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -191,7 +191,11 @@ def validate_puzzle_solutions(args: DictConfig) -> None: pass save_checkpoint(model, checkpoint_dir, descriptor) - copy_tokenizer(args.tokenizer_name, checkpoint_dir) + copy_tokenizer( + args.tokenizer_name, + checkpoint_dir, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) dist.barrier() From 2a170b994fbd666dedfea2b51574cecd17badfad Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 30 Mar 2026 08:58:25 -0700 Subject: [PATCH 58/62] Set weights_only=True in checkpoint_utils.py Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../torch/prune/importance_hooks/base_hooks.py | 2 +- .../importance_hooks/compare_module_outputs.py | 16 +++++++++++----- .../torch/puzzletron/tools/checkpoint_utils.py | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/prune/importance_hooks/base_hooks.py b/modelopt/torch/prune/importance_hooks/base_hooks.py index a28908d4b6..44eea3bdbe 100644 --- a/modelopt/torch/prune/importance_hooks/base_hooks.py +++ b/modelopt/torch/prune/importance_hooks/base_hooks.py @@ -735,7 +735,7 @@ def _save_channel_importance_results( all_scores = [] for activation_file in activation_files: print(f"Loading activations from {activation_file}") - # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are generated by dump_activations_logs() in this module and contain # hook state dictionaries. The activations_log_dir should only contain trusted files # generated by the same codebase, not from untrusted sources. diff --git a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py index e692a518ae..0f4c954e31 100644 --- a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py +++ b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py @@ -52,7 +52,8 @@ python compare_module_outputs.py \ --reference output_unpruned.pt \ --compare output_l2norm.pt \ - --output-json comparison_stats.json + --output-json comparison_stats.json \ + --trust-inputs The saved file format\: @@ -180,21 +181,26 @@ def main(): default=None, help="Path to save comparison statistics as JSON", ) + parser.add_argument( + "--trust-inputs", + action="store_true", + help="Trust input files for loading with weights_only=False in torch.load()", + ) args = parser.parse_args() # Load reference data print(f"\nLoading reference: {args.reference}") - # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are expected to be generated by save_multi_layer_outputs() in this module, # not from untrusted sources. Users should only load files they generated themselves. - ref_data = torch.load(args.reference, map_location="cpu", weights_only=False) + ref_data = torch.load(args.reference, map_location="cpu", weights_only=args.trust_inputs) # Load comparison data print(f"Loading compare: {args.compare}") - # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are expected to be generated by save_multi_layer_outputs() in this module, # not from untrusted sources. Users should only load files they generated themselves. - comp_data = torch.load(args.compare, map_location="cpu", weights_only=False) + comp_data = torch.load(args.compare, map_location="cpu", weights_only=args.trust_inputs) # Compare multi-layer outputs compare_multi_layer(ref_data, comp_data, args.output_json) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py index 8b14bd8027..26d640fc31 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -40,7 +40,7 @@ def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: checkpoint_dir = _normalize_checkpoint_dir(checkpoint_dir) if (state_dict_path := checkpoint_dir / STATE_DICT_FILE_NAME).exists(): - return torch.load(state_dict_path, map_location="cpu", weights_only=False) + return torch.load(state_dict_path, map_location="cpu", weights_only=True) if (safetensors_subblocks_dir := checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists(): return _load_state_dict_from_subblocks(safetensors_subblocks_dir) @@ -74,7 +74,7 @@ def _load_state_dict_from_subblocks(subblocks_dir: Path) -> dict[str, torch.Tens safetensors_paths = list(subblocks_dir.glob("*.safetensors")) if len(torch_paths) != 0: - load_fn = partial(torch.load, map_location="cpu", weights_only=False) + load_fn = partial(torch.load, map_location="cpu", weights_only=True) file_paths = torch_paths elif len(safetensors_paths) != 0: load_fn = safe_load_file From d6f8ddb0a4c8d5da6841a8c6abbfd0f98e27fe9b Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:54:47 -0700 Subject: [PATCH 59/62] More fixes Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../compare_module_outputs.py | 4 +-- modelopt/torch/puzzletron/mip/run_puzzle.py | 2 +- .../puzzletron/tools/checkpoint_utils.py | 2 +- modelopt/torch/puzzletron/tools/logger.py | 2 +- ...validate_puzzle_with_multi_replacements.py | 14 +++++--- .../mbridge_distillation/test_distill_hf.py | 5 --- tests/gpu/torch/puzzletron/test_puzzletron.py | 20 ++++++----- .../torch/puzzletron/test_convert_anymodel.py | 35 +++++++++++++++++++ tox.ini | 4 +++ 9 files changed, 65 insertions(+), 23 deletions(-) create mode 100644 tests/unit/torch/puzzletron/test_convert_anymodel.py diff --git a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py index 0f4c954e31..dbb4f564d7 100644 --- a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py +++ b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py @@ -193,14 +193,14 @@ def main(): # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are expected to be generated by save_multi_layer_outputs() in this module, # not from untrusted sources. Users should only load files they generated themselves. - ref_data = torch.load(args.reference, map_location="cpu", weights_only=args.trust_inputs) + ref_data = torch.load(args.reference, map_location="cpu", weights_only=not args.trust_inputs) # Load comparison data print(f"Loading compare: {args.compare}") # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are expected to be generated by save_multi_layer_outputs() in this module, # not from untrusted sources. Users should only load files they generated themselves. - comp_data = torch.load(args.compare, map_location="cpu", weights_only=args.trust_inputs) + comp_data = torch.load(args.compare, map_location="cpu", weights_only=not args.trust_inputs) # Compare multi-layer outputs compare_multi_layer(ref_data, comp_data, args.output_json) diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index dd50003d1b..803fd83db3 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -407,7 +407,7 @@ def _assert_valid_config(args, puzzle_profile): # Make sure we have specified subblock_stats_args if "subblock_stats_args" not in args and "subblock_stats_args" not in puzzle_profile: mprint( - "error: Must specify `subblock_stats_arrs` in either puzzle_profile or as a commandline arg." + "error: Must specify `subblock_stats_args` in either puzzle_profile or as a commandline arg." ) sys.exit(1) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py index 26d640fc31..4488898e33 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -145,7 +145,7 @@ def is_valid_decilm_checkpoint(checkpoint_dir: Path | str, trust_remote_code: bo """ try: model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) - if model_config.block_configs is None: + if not hasattr(model_config, "block_configs") or model_config.block_configs is None: warnings.warn( f"Skipping checkpoint '{checkpoint_dir}' - missing block_configs (not an AnyModel-style layout)" ) diff --git a/modelopt/torch/puzzletron/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py index 02a69512ba..257e55abe3 100644 --- a/modelopt/torch/puzzletron/tools/logger.py +++ b/modelopt/torch/puzzletron/tools/logger.py @@ -62,7 +62,7 @@ def dist_log(self, msg: str, ranks: str = "main"): if ranks not in ["all", "main", "local_main", "last"]: raise NotImplementedError( f"Could not broadcast msg {msg} - " - f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}" + f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}" ) # All ranks to print if ranks == "all": diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index a07590452f..f647cd3f89 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -124,7 +124,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None: args.solutions_to_validate = list(range(len(puzzle_solutions))) puzzle_solutions = [puzzle_solutions[i] for i in args.solutions_to_validate] - tokenizer = _load_tokenizer(args) + tokenizer = _load_tokenizer(args, trust_remote_code=descriptor.requires_trust_remote_code()) if not args.skip_validation: val_dataloader = ( validate_model.prepare_dataloader(args, tokenizer) if dist.is_master() else None @@ -231,14 +231,18 @@ def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool: return True -def _load_tokenizer(args: DictConfig) -> PreTrainedTokenizerBase: +def _load_tokenizer(args: DictConfig, trust_remote_code: bool = False) -> PreTrainedTokenizerBase: tokenizer = None if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code + ) elif args.teacher_dir is not None: try: - tokenizer = AutoTokenizer.from_pretrained(args.teacher_dir, trust_remote_code=True) - except: + tokenizer = AutoTokenizer.from_pretrained( + args.teacher_dir, trust_remote_code=trust_remote_code + ) + except Exception: pass if tokenizer is None: warnings.warn("Couldn't find a tokenizer, trying to continue without one") diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py index db886f5050..cbc042c5e7 100644 --- a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -87,11 +87,6 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): model = AutoModelForCausalLM.from_pretrained(hf_export_dir) assert model is not None, "Failed to load distilled model with AutoModelForCausalLM" - print( - f"PYTEST SUMMARY: test_distill_hf test has finished successfully. " - f"Output directory: {output_dir}, HF export: {hf_export_dir}" - ) - def _prepare_student_and_teacher_models( project_root_path: Path, tmp_path: Path diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index 88635a6aaf..fa73d15b6c 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json +import warnings from datetime import timedelta from functools import partial from pathlib import Path @@ -171,11 +171,6 @@ def _test_puzzletron_multiprocess_job( dist.cleanup() - print( - f"PYTEST SUMMARY: test_puzzletron({hf_model_name}) test has finished successfully. " - f"Puzzle directory: {puzzle_dir}" - ) - def _assert_subblock_stats_anymodel(hf_model_name: str, hydra_cfg) -> None: """Minimal subblock_stats checks and teacher memory / param regression values.""" @@ -205,6 +200,10 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): expected = EXPECTED_PRUNING_VALUES[hf_model_name] size = dist.size() + if hf_model_name == "mistralai/Mistral-Small-24B-Instruct-2501" and size == 1: + warnings.warn("Mistral-Small score assertions only work for 2 GPUs") + return + if expected is not None: # In multi-GPU: layers are distributed across ranks # Each rank processes len(expected) // size layers @@ -217,13 +216,17 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): layer_data = pruning_scores[layer_name] # Calculate global layer index from rank and local index global_idx = rank * expected_layers_per_rank + i - assert layer_data["score"][0].item() == expected[global_idx]["score"] + assert layer_data["score"][0].item() == expected[global_idx]["score"], ( + layer_name, + layer_data["score"][0].item(), + expected[global_idx]["score"], + global_idx, + ) assert ( layer_data["channels_importance_ascending"][0].item() == expected[global_idx]["channels"] ) else: - # Print values for new models - update EXPECTED_PRUNING_VALUES with these print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={len(layer_names)}) ===") print(f'"{hf_model_name}": [') for layer_name in layer_names: @@ -233,6 +236,7 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): print(f' {{"score": {score}, "channels": {channels}}},') print("],") print("===") + pytest.fail(f"Expected pruning values not found for {hf_model_name}") def _assert_lm_loss(puzzle_dir: Path, hf_model_name: str, tolerance: float = 0.01): diff --git a/tests/unit/torch/puzzletron/test_convert_anymodel.py b/tests/unit/torch/puzzletron/test_convert_anymodel.py new file mode 100644 index 0000000000..f27cb9c9b9 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_convert_anymodel.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch.transformers_models import create_tiny_qwen3_dir +from transformers import AutoModelForCausalLM + +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + +def test_convert_anymodel(tmp_path): + input_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True) + output_dir = tmp_path / "qwen3-0.6b-anymodel" + convert_model(input_dir, output_dir, converter="qwen3") + + descriptor = ModelDescriptorFactory.get("qwen3") + with deci_x_patcher(descriptor): + _ = AutoModelForCausalLM.from_pretrained(output_dir) diff --git a/tox.ini b/tox.ini index 80299d814d..7948a19f5c 100644 --- a/tox.ini +++ b/tox.ini @@ -66,6 +66,10 @@ commands_pre = # Install cupy-cuda13x for INT4 ONNX quantization (default is cupy-cuda12x) pip uninstall -y cupy-cuda12x pip install cupy-cuda13x + + # Install mamba and causal-conv1d for Nemotron tests + pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git + pip install --no-build-isolation git+https://github.com/Dao-AILab/causal-conv1d.git commands = # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" python -m pytest tests/gpu From 4621b65dad6e06d1ac94cd84a2eea5967bdf98f6 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:26:10 -0700 Subject: [PATCH 60/62] reuse puzzletron tokenizer in other tests Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- tests/_test_utils/torch/puzzletron/utils.py | 68 +++++-------------- .../torch}/tokenizer/special_tokens_map.json | 0 .../torch}/tokenizer/tokenizer.json | 0 .../torch}/tokenizer/tokenizer_config.json | 0 .../_test_utils/torch/transformers_models.py | 14 ++-- tests/conftest.py | 2 +- .../mbridge_distillation/test_distill_hf.py | 7 +- tests/gpu/torch/puzzletron/conftest.py | 24 ------- .../nas/plugins/test_nas_convert.py | 4 +- .../puzzletron/nas/plugins/test_nas_search.py | 2 +- .../resources/tokenizer/truncate_tokenizer.py | 62 ----------------- tests/gpu/torch/puzzletron/test_puzzletron.py | 2 +- 12 files changed, 33 insertions(+), 152 deletions(-) rename tests/{gpu/torch/puzzletron/resources => _test_utils/torch}/tokenizer/special_tokens_map.json (100%) rename tests/{gpu/torch/puzzletron/resources => _test_utils/torch}/tokenizer/tokenizer.json (100%) rename tests/{gpu/torch/puzzletron/resources => _test_utils/torch}/tokenizer/tokenizer_config.json (100%) delete mode 100644 tests/gpu/torch/puzzletron/conftest.py delete mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index b5e32566de..fc6d6d5c16 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -14,39 +14,32 @@ # limitations under the License. import os -import shutil from pathlib import Path import torch +from _test_utils.torch.transformers_models import get_tiny_tokenizer from datasets import Dataset, DatasetDict -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers def setup_test_model_and_data( - project_root_path: Path, - tmp_path: Path, - rank: int, - hf_model_name: str, - hybrid_override_pattern: str | None = None, + tmp_path: Path, rank: int, hf_model_name: str, hybrid_override_pattern: str | None = None ) -> tuple[Path, Path, Path]: """ Setup the test model and data for the puzzletron NAS search. Args: - project_root_path (Path): the root path of the project - tmp_path (Path): the temporary path to use for the test - rank (int): the rank of the process - hf_model_name (str): HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") - hybrid_override_pattern (str): For NemotronH models, the layer type pattern + tmp_path: the temporary path to use for the test + rank: the rank of the process + hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") + hybrid_override_pattern: For NemotronH models, the layer type pattern Returns: - tuple[Path, Path, Path]: - the puzzle_dir, hf_checkpoint_path, dataset_path + tuple[Path, Path, Path]: the puzzle_dir, hf_checkpoint_path, dataset_path """ - # Register Hydra custom resolvers (needed for config resolution) register_hydra_resolvers() @@ -55,31 +48,23 @@ def setup_test_model_and_data( dataset_path = puzzle_dir / "dummy_dataset" if rank == 0: - # Setup puzzle_dir and dataset - setup_puzzle_dir(puzzle_dir) save_dummy_dataset(dataset_path) # Create a small HF model - tokenizer = create_tokenizer(project_root_path) + tokenizer = get_tiny_tokenizer() create_and_save_small_hf_model( output_path=str(hf_checkpoint_path), - vocab_size=tokenizer.vocab_size, tokenizer=tokenizer, hf_model_name=hf_model_name, hybrid_override_pattern=hybrid_override_pattern, ) dist.barrier() - return ( - puzzle_dir, - hf_checkpoint_path, - dataset_path, - ) + return puzzle_dir, hf_checkpoint_path, dataset_path def create_and_save_small_hf_model( output_path: str, - vocab_size: int, tokenizer: PreTrainedTokenizerBase, hf_model_name: str, hybrid_override_pattern: str | None = None, @@ -91,14 +76,11 @@ def create_and_save_small_hf_model( Args: output_path: Where to save the model - vocab_size: Vocabulary size (should match tokenizer) tokenizer: Tokenizer to save alongside the model hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") hybrid_override_pattern: For NemotronH models, the layer type pattern (e.g., "*-" for Attention+MLP, "M-" for Mamba+MLP). Must match num_hidden_layers. None for non-NemotronH models. """ - os.makedirs(output_path, exist_ok=True) - # Load real HuggingFace config (preserves tie_word_embeddings, rope_scaling, etc.) config = AutoConfig.from_pretrained(hf_model_name, trust_remote_code=True) @@ -108,7 +90,7 @@ def create_and_save_small_hf_model( # VL models have nested configs (text_config, vision_config) if hasattr(config, "text_config") and hasattr(config, "vision_config"): - config.text_config.vocab_size = vocab_size + config.text_config.vocab_size = tokenizer.vocab_size config.text_config.hidden_size = 256 config.text_config.intermediate_size = 512 config.text_config.num_hidden_layers = 2 @@ -126,7 +108,7 @@ def create_and_save_small_hf_model( config.num_hidden_layers = config.text_config.num_hidden_layers else: # Regular models have flat config - config.vocab_size = vocab_size + config.vocab_size = tokenizer.vocab_size config.hidden_size = 256 config.intermediate_size = 512 config.num_hidden_layers = max(2, dist.size()) @@ -147,7 +129,10 @@ def create_and_save_small_hf_model( config.hybrid_override_pattern = hybrid_override_pattern # Ensure pad_token_id is within vocab_size (nn.Embedding requires padding_idx < num_embeddings) - if getattr(config, "pad_token_id", None) is not None and config.pad_token_id >= vocab_size: + if ( + getattr(config, "pad_token_id", None) is not None + and config.pad_token_id >= tokenizer.vocab_size + ): config.pad_token_id = 0 # Set seed for reproducible weight initialization @@ -171,7 +156,7 @@ def create_and_save_small_hf_model( model.initialize_weights() # Fix any remaining NaN/Inf values that initialize_weights() might have missed - for name, param in model.named_parameters(): + for param in model.parameters(): if torch.isnan(param).any() or torch.isinf(param).any(): nan_inf_mask = torch.isnan(param) | torch.isinf(param) param.data = torch.where(nan_inf_mask, torch.zeros_like(param), param) @@ -191,25 +176,6 @@ def create_and_save_small_hf_model( config.save_pretrained(output_path) -def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: - """ - Create a tokenizer for the model. - """ - tokenizer_path = project_root_path / "tests/gpu/torch/puzzletron/resources/tokenizer" - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - return tokenizer - - -def setup_puzzle_dir(puzzle_dir: str | Path): - """ - Setup puzzle directory by removing existing directory and creating a new one. - """ - puzzle_dir = Path(puzzle_dir) - if puzzle_dir.exists(): - shutil.rmtree(puzzle_dir) - puzzle_dir.mkdir(parents=True, exist_ok=True) - - def save_dummy_dataset(dataset_path: Path | str): """ Save a dummy dataset for testing purposes. diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json b/tests/_test_utils/torch/tokenizer/special_tokens_map.json similarity index 100% rename from tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json rename to tests/_test_utils/torch/tokenizer/special_tokens_map.json diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json b/tests/_test_utils/torch/tokenizer/tokenizer.json similarity index 100% rename from tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json rename to tests/_test_utils/torch/tokenizer/tokenizer.json diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json b/tests/_test_utils/torch/tokenizer/tokenizer_config.json similarity index 100% rename from tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json rename to tests/_test_utils/torch/tokenizer/tokenizer_config.json diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 01e8fa4d38..54bc10a562 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -39,6 +39,12 @@ SEED = 1234 +TINY_TOKENIZER_PATH = Path(__file__).parent / "tokenizer" + + +def get_tiny_tokenizer() -> "transformers.PreTrainedTokenizerBase": + return AutoTokenizer.from_pretrained(TINY_TOKENIZER_PATH) + ##### Qwen3 ##### def get_tiny_qwen3(**config_kwargs) -> PreTrainedModel: @@ -66,9 +72,7 @@ def create_tiny_qwen3_dir( ) -> Path | tuple[Path, PreTrainedModel]: qwen3_dir = Path(tmp_path) / "tiny_qwen3" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(qwen3_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size tiny_qwen3 = get_tiny_qwen3(**config_kwargs) @@ -149,9 +153,7 @@ def create_tiny_llama_dir( ) -> Path: llama_dir = Path(tmp_path) / "tiny_llama" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(llama_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size diff --git a/tests/conftest.py b/tests/conftest.py index 53a2330c22..a4e65ff2ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,7 +115,7 @@ def enable_hf_checkpointing(): mto.enable_huggingface_checkpointing() -@pytest.fixture +@pytest.fixture(scope="session") def project_root_path(request: pytest.FixtureRequest) -> Path: """Fixture providing the project root path for tests.""" return Path(request.config.rootpath) diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py index cbc042c5e7..6ca0ac0dd9 100644 --- a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -19,7 +19,8 @@ import torch from _test_utils.examples.run_command import extend_cmd_parts, run_example_command from _test_utils.torch.distributed.utils import get_free_port -from _test_utils.torch.puzzletron.utils import create_and_save_small_hf_model, create_tokenizer +from _test_utils.torch.puzzletron.utils import create_and_save_small_hf_model +from _test_utils.torch.transformers_models import get_tiny_tokenizer from transformers import AutoModelForCausalLM from modelopt.torch.puzzletron.anymodel import convert_model @@ -103,13 +104,12 @@ def _prepare_student_and_teacher_models( teacher_hf_dir = tmp_path / "teacher_hf" # Create tokenizer (uses local tokenizer from test resources) - tokenizer = create_tokenizer(project_root_path) + tokenizer = get_tiny_tokenizer() # Create student model using utility function (loads config from Hub). # TODO: Make the student model using different ffn sizes across layers. create_and_save_small_hf_model( output_path=str(student_hf_dir), - vocab_size=tokenizer.vocab_size, tokenizer=tokenizer, hf_model_name="Qwen/Qwen3-0.6B", hybrid_override_pattern=None, @@ -118,7 +118,6 @@ def _prepare_student_and_teacher_models( # Create teacher model (same as student for testing) create_and_save_small_hf_model( output_path=str(teacher_hf_dir), - vocab_size=tokenizer.vocab_size, tokenizer=tokenizer, hf_model_name="Qwen/Qwen3-0.6B", hybrid_override_pattern=None, diff --git a/tests/gpu/torch/puzzletron/conftest.py b/tests/gpu/torch/puzzletron/conftest.py deleted file mode 100644 index cae1bfbca5..0000000000 --- a/tests/gpu/torch/puzzletron/conftest.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path - -import pytest - - -@pytest.fixture -def project_root_path(request: pytest.FixtureRequest) -> Path: - """Fixture providing the project root path for tests.""" - return Path(request.config.rootpath) diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index 8a5bad0c62..25991f1c74 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -41,7 +41,7 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" ) hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" @@ -97,7 +97,7 @@ def _test_nas_convert_attn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" ) hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning" diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index 2af371e5ca..aede36bded 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -40,7 +40,7 @@ def _test_nas_search_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" ) hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py b/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py deleted file mode 100644 index aedcae4ab2..0000000000 --- a/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script was used to truncate the tokenizer.json file from Llama 3.1 8B model -to keep only the top 100 most common tokens. -""" - -import json - -# Path to your original and new tokenizer.json -in_path = "./tokenizer.json" -out_path = "./tokenizer_truncated.json" - -# How many top tokens to keep -NUM_TO_KEEP = 100 - -with open(in_path, encoding="utf-8") as f: - tokenizer_data = json.load(f) - -# Get and sort the original vocab by index (frequency proxy) -orig_vocab = tokenizer_data["model"]["vocab"] - -# Sort tokens by their original index (lowest index = assumed most common/important) -sorted_tokens = sorted(orig_vocab.items(), key=lambda item: item[1]) - -# Keep the top N tokens -tokens_to_keep = [tok for tok, idx in sorted_tokens[:NUM_TO_KEEP]] - -# Re-index the selected tokens: 0..N-1 -small_vocab = {tok: i for i, tok in enumerate(tokens_to_keep)} -tokenizer_data["model"]["vocab"] = small_vocab - -# Update vocab size -if "vocab_size" in tokenizer_data["model"]: - tokenizer_data["model"]["vocab_size"] = len(small_vocab) - -# Optionally remove merges if present and unneeded (mostly for BPE/WordPiece) -if "merges" in tokenizer_data["model"]: - tokenizer_data["model"]["merges"] = [] - -# Remove added_tokens if not needed -if "added_tokens" in tokenizer_data: - tokenizer_data["added_tokens"] = [] - -# Write out the truncated tokenizer.json -with open(out_path, "w", encoding="utf-8") as f: - json.dump(tokenizer_data, f, indent=2, ensure_ascii=False) - -print(f"Truncated tokenizer saved to: {out_path}") diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index fa73d15b6c..45c438ec0d 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -94,7 +94,7 @@ def _test_puzzletron_multiprocess_job( # Setup the test model and data. puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank, hf_model_name, hybrid_override_pattern + tmp_path, rank, hf_model_name, hybrid_override_pattern ) hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" model_basename = hf_model_name.split("/")[1] From be4bd3aa8b1fc7f5981d90ae61c298884193c909 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:27:55 -0700 Subject: [PATCH 61/62] disable puzzletron in coverage check as its covered in gpu tests only Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e9d40f466a..a03e2029dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -292,7 +292,7 @@ markers = [ [tool.coverage.run] branch = false include = ["modelopt/*"] -omit = ["*/plugins/*", "*/export/*"] +omit = ["*/plugins/*", "*/export/*", "modelopt/torch/puzzletron/*"] [tool.coverage.report] From 45426ca8c1fa5614ab2ac7e2154b61a616a8d135 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:45:41 +0530 Subject: [PATCH 62/62] Remove custom DistillationProvider and simplify mbridge distillation and hf export (#1122) - In nemo:26.02.01 container, we have DistillationProvider fix for homogeneous models already. That seems sufficient for Heterogeneous models as well so removing copied DistillationProvider to simplify - Replace hacky megatron to hf export logic with simplified one ## Summary by CodeRabbit * **Refactor** * Reworked distillation and HuggingFace export flow to use upstream bridge/export APIs, removed local monkey-patching and extra exception logging, and simplified distributed cleanup. * **Chores** * Consolidated and renamed Qwen3 / Qwen3-VL model and converter registrations; updated pruning configs, CLI export flags, and packaging lint/dependency settings. * **Tests** * Updated integration tests to use Qwen3 checkpoints and adjusted export verification. * **Documentation** * Updated README example to reflect new CLI usage. --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../puzzletron/mbridge_distillation/README.md | 12 +- .../mbridge_distillation/distill_hf.py | 87 +++----- .../torch/puzzletron/export/mbridge/base.py | 15 +- .../export/mbridge/distillation_provider.py | 190 ------------------ .../export/mbridge/export_mbridge_to_hf.py | 81 -------- .../mbridge_distillation/test_distill_hf.py | 8 +- 6 files changed, 52 insertions(+), 341 deletions(-) delete mode 100644 modelopt/torch/puzzletron/export/mbridge/distillation_provider.py delete mode 100644 modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md index f7dda866e8..9658e48ebc 100644 --- a/examples/puzzletron/mbridge_distillation/README.md +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -22,7 +22,7 @@ git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} **Start Docker container:** -Use the [NeMo 26.02.01 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02.01): +Use the [NeMo 26.02 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02): ```bash # Recommended to mount a workspace directory for storing datasets and distilled models @@ -31,7 +31,7 @@ docker run --gpus all -it --rm \ -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ -w /opt/Model-Optimizer \ - nvcr.io/nvidia/nemo:26.02.01 \ + nvcr.io/nvidia/nemo:26.02 \ /bin/bash ``` @@ -66,12 +66,12 @@ Run distillation directly from HuggingFace checkpoints (student and teacher) wit ```bash torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \ - --student_hf_path /path/to/student/huggingface/checkpoint \ + --student_hf_path /path/to/student/puzzletron/checkpoint \ + --student_hf_model meta-llama/Llama-3.1-8B-Instruct \ --teacher_hf_path /path/to/teacher/huggingface/checkpoint \ --data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \ --output_dir /path/to/distilled/checkpoint \ - --hf-export-path /path/to/exported/hf/model \ - --hf-model meta-llama/Llama-3.1-8B-Instruct \ + --hf_export_path /path/to/exported/hf/model \ --seq_length 4096 \ --tp_size 8 \ --pp_size 1 \ @@ -90,7 +90,7 @@ torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf. - Add `--trust_remote_code` if student or teacher checkpoints need HuggingFace custom modeling code. - The distilled Megatron-Bridge checkpoint will be saved to `--output_dir/checkpoints/iter_`. -- Add `--hf-export-path` (or `--hf_export_path`) to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--hf-model` / `--hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation). +- Add `--hf_export_path` to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--student_hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation). - For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices. ## MMLU Evaluation Results diff --git a/examples/puzzletron/mbridge_distillation/distill_hf.py b/examples/puzzletron/mbridge_distillation/distill_hf.py index ac703909c2..30fff0ca14 100644 --- a/examples/puzzletron/mbridge_distillation/distill_hf.py +++ b/examples/puzzletron/mbridge_distillation/distill_hf.py @@ -22,11 +22,11 @@ import argparse import os -import traceback +import shutil -import megatron.bridge.models.distillation_provider import torch from megatron.bridge import AutoBridge +from megatron.bridge.models.distillation_provider import convert_to_distillation_provider from megatron.bridge.recipes.utils.optimizer_utils import ( distributed_fused_adam_with_cosine_annealing, ) @@ -40,39 +40,16 @@ TokenizerConfig, TrainingConfig, ) +from megatron.bridge.training.distill import distill from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig from megatron.core.datasets.utils import get_blend_from_list from megatron.core.distributed import DistributedDataParallelConfig -# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure -# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers -# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge -# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration. -# -# Note: Currently, bridges are also registered when distillation_provider is imported -# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider -# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron. +# Import to register heterogeneous bridges (side effect) import modelopt.torch.puzzletron.export.mbridge # noqa: F401 import modelopt.torch.utils.distributed as dist - -# Use local copy of distillation_provider with fix for heterogeneous models -# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge -from modelopt.torch.puzzletron.export.mbridge.distillation_provider import ( - DistillationProvider, - convert_to_distillation_provider, -) -from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import ( - export_to_hf_and_copy_config, -) from modelopt.torch.utils import print_rank_0 -# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider -# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time -megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider - -# Import distill() AFTER patching so it uses the patched DistillationProvider -from megatron.bridge.training.distill import distill # noqa: E402 - SEED = 1234 @@ -84,13 +61,13 @@ def get_args(): "--student_hf_path", type=str, required=True, - help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)", + help="HuggingFace model name or path for the student (standard HF format or puzzletron any_model format)", ) parser.add_argument( "--teacher_hf_path", type=str, required=True, - help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)", + help="HuggingFace model name or path for the teacher (standard HF format or puzzletron any_model format)", ) parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code") # Parallelism arguments @@ -145,21 +122,20 @@ def get_args(): # Export arguments parser.add_argument( "--hf_export_path", - "--hf-export-path", type=str, default=None, help=( "Path where to save the HuggingFace export. " - "If provided, exports checkpoint to HF format after distillation." + "If provided, exports last iteration checkpoint to HF format after distillation." ), ) parser.add_argument( - "--hf_model", - "--hf-model", + "--student_hf_model", type=str, - required=True, - help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). " - "Should match the base architecture of the student model.", + required=False, + default=None, + help="HuggingFace model ID to use as template for export (e.g., Qwen/Qwen3-0.6B). " + "Should match the base architecture of the student model if --hf_export_path is provided.", ) args = parser.parse_args() @@ -167,6 +143,9 @@ def get_args(): if not args.use_mock_data and not args.data_paths: raise ValueError("Must provide either --data_paths or set --use_mock_data.") + if args.hf_export_path and not args.student_hf_model: + raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.") + print_rank_0("\n==================== Arguments ====================") for k, v in args.__dict__.items(): print_rank_0(f"{k:<35} {v}") @@ -288,32 +267,28 @@ def _build_model_provider(hf_path): # Export to HuggingFace format if hf_export_path is provided if args.hf_export_path: - # Wait for all ranks to finish distillation before export - if torch.distributed.is_initialized(): - torch.distributed.barrier() - + print_rank_0(f"Exporting final distilled ckpt to HF format to {args.hf_export_path}") # Save rank before destroying process group (dist.rank() won't work after destruction) is_rank_0 = dist.rank() == 0 # Destroy process group on all ranks - export_ckpt will create its own temporary one # This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone) - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() + dist.cleanup() # Only rank 0 exports if is_rank_0: - try: - export_to_hf_and_copy_config( - student_hf_path=args.student_hf_path, - checkpoint_dir=checkpoint_dir, - train_iters=args.train_iters, - hf_export_path=args.hf_export_path, - hf_model=args.hf_model, - trust_remote_code=args.trust_remote_code, - ) - except Exception as e: - print(f"⚠️ Export failed: {e}") - traceback.print_exc() + export_bridge = AutoBridge.from_hf_pretrained( + args.student_hf_model, trust_remote_code=args.trust_remote_code + ) + export_bridge.export_ckpt( + megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}", + hf_path=args.hf_export_path, + show_progress=True, + strict=True, + ) + + # save config from student_model to hf_export_path + shutil.copy(f"{args.student_hf_path}/config.json", f"{args.hf_export_path}/config.json") if __name__ == "__main__": @@ -321,9 +296,5 @@ def _build_model_provider(hf_path): args = get_args() try: main(args) - except Exception as e: - print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}") - print_rank_0(f"Traceback:\n{traceback.format_exc()}") - raise finally: dist.cleanup() diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py index 13ea6612af..4f01f800e6 100644 --- a/modelopt/torch/puzzletron/export/mbridge/base.py +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -28,12 +28,20 @@ from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM -from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig +from megatron.bridge.models.transformer_config import ( + HeterogeneousTransformerConfig, + TransformerConfig, +) from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( get_gpt_heterogeneous_layer_spec, ) from megatron.core.transformer.spec_utils import ModuleSpec +# Monkey-patch: add get_config_for_layer to TransformerConfig if missing +# (needed for non-heterogeneous teacher models in this container version) +if not hasattr(TransformerConfig, "get_config_for_layer"): + TransformerConfig.get_config_for_layer = lambda self, layer_number: self + def heterogeneous_layer_spec(config) -> ModuleSpec: """Get GPT heterogeneous layer spec using Transformer Engine.""" @@ -87,9 +95,12 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all the fields that the parent bridge sets. """ - parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] + # If no block_configs, fall back to standard (non-heterogeneous) provider. + if not (hasattr(hf_pretrained.config, "block_configs")): + return parent_provider + provider_kwargs = dataclasses.asdict(parent_provider) # Filter to only fields that GenericHeterogeneousProvider accepts. diff --git a/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py deleted file mode 100644 index fa49dc29c5..0000000000 --- a/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py +++ /dev/null @@ -1,190 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO: Upstream this fix to Megatron-Bridge and remove this local copy. - -import logging -from dataclasses import dataclass, fields -from typing import TYPE_CHECKING, Any, Optional - -from megatron.bridge.models.gpt_provider import GPTModelProvider -from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider -from megatron.bridge.models.transformer_config import TransformerConfig -from megatron.core.models.gpt import GPTModel as MCoreGPTModel - -import modelopt.torch.distill as mtd -import modelopt.torch.distill.plugins.megatron as mtd_mcore - -if TYPE_CHECKING: - from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig - - -logger = logging.getLogger(__name__) - - -@dataclass -class DistillationProvider(TransformerConfig): - """Provider for Megatron Core GPT models in distillation mode. - - Please use `convert_to_distillation_provider()` to create an instance of this class. - """ - - teacher: Optional[GPTModelProvider | MambaModelProvider] = None - kd_config: Optional["ModelOptDistillConfig"] = None - - def __init__(self, *args, **kwargs): - raise NotImplementedError( - "Use `convert_to_distillation_provider()` to create an instance of this class." - ) - - def __post_init__(self): - assert getattr(self, "teacher", None) is not None, "Teacher model must be provided." - - shared_attrs = [ - "tensor_model_parallel_size", - "pipeline_model_parallel_size", - "context_parallel_size", - "seq_length", - "pipeline_dtype", - ] - for attr in shared_attrs: - if getattr(self, attr) != getattr(self.teacher, attr): - raise ValueError(f"Student and teacher providers must have the same {attr}.") - - # Logits are overwritten in-place when TE cross-entropy loss is enabled, so switch it back to native version. - self.cross_entropy_fusion_impl = "native" - - # Hack to dynamically subclass other providers and still use their methods - self._super_class = self.__class__.__bases__[0] - - def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: - """Configure and instantiate a ModelOpt DistillationModel based on this configuration. - - Args: - pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage - post_process: Whether to include post-processing in the model, defaults to last pipeline stage - vp_stage: Virtual pipeline stage - - Returns: - MCoreGPTModel: Configured ModelOpt DistillationModel instance - """ - if vp_stage is not None: - raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.") - - assert self.teacher is not None, "Teacher model must be provided." - student_model = self._super_class.provide(self, pre_process, post_process, vp_stage) # type: ignore[attr-defined] - - # Finalize teacher provider before creating model (required for heterogeneous models). - # - # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in - # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It's created during - # provider creation (bridge.to_megatron_provider()), but finalize() ensures they're consistent - # with current parallelism settings and distributed context. Student model creation (above) - # initializes parallel_state (process groups, TP/PP config), which weight loading/scatter - # requires. During teacher model creation, get_config_for_layer() is called (transformer_block.py:341) - # for each layer, which uses per_block_parameters and current tensor_model_parallel_size to - # determine layer architecture. Without finalize() in this context, architecture expectations - # don't match checkpoint weights, causing: - # ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0 - # (expected (2880, 4096), got (3584, 4096)) - # - # Note: This explanation needs to be confirmed yet. - self.teacher.finalize() - - # Hack to get teacher's pre-wrap hooks called to potentially load HF weights - teacher_model = self.teacher.provide_distributed_model( - wrap_with_ddp=False, mixed_precision_wrapper=None - )[0] - - kd_cfg = mtd_mcore.setup_distillation_config( - self.kd_config, student_model.config, teacher_model.config - ) - modelopt_cfg = { - "teacher_model": teacher_model, - "criterion": kd_cfg.criterion, - "loss_balancer": kd_cfg.loss_balancer, - } - kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)]) - mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg) - - return kd_model - - def to_cfg_dict(self) -> dict[str, Any]: - """Custom method to save equivalent to the original provider class. - - Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML. - There is no need to restore a `DistillationProvider` from the run config file, as - it can always be re-converted using the original student provider. - - Returns: - Dictionary representation of this provider class - """ - from megatron.bridge.training.utils.config_utils import _ConfigContainerBase - - result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"} - - # Include all fields from the original provider class (self._super_class), not just DistillationProvider - # This ensures fields like heterogeneous_layers_config_encoded_json are preserved - excluded_fields = {"teacher", "kd_config"} - for field in fields(self._super_class): - if field.name.startswith("_") or field.name in excluded_fields: - continue - # Only include if the field exists on this instance (it should, since we converted from the original provider) - if hasattr(self, field.name): - result[field.name] = _ConfigContainerBase._convert_value_to_dict( - getattr(self, field.name) - ) - - # Also include any additional fields from DistillationProvider itself (if any) - for field in fields(self): - if field.name.startswith("_") or field.name in excluded_fields: - continue - # Skip if already included from _super_class - if field.name not in result: - result[field.name] = _ConfigContainerBase._convert_value_to_dict( - getattr(self, field.name) - ) - - return result - - def __setattr__(self, name, value): - super().__setattr__(name, value) - # Mirror to teacher if it has that attribute - if hasattr(self.teacher, name): - setattr(self.teacher, name, value) - - -def convert_to_distillation_provider( - student_provider: GPTModelProvider | MambaModelProvider, - teacher_provider: GPTModelProvider | MambaModelProvider, - kd_config: Optional["ModelOptDistillConfig"] = None, -) -> "DistillationProvider": - """Convert a given model provider to a DistillationProvider.""" - - assert isinstance(student_provider, (GPTModelProvider, MambaModelProvider)), ( - "Student provider must be a subclass of GPTModelProvider or MambaModelProvider." - ) - assert isinstance(teacher_provider, (GPTModelProvider, MambaModelProvider)), ( - "Teacher provider must be a subclass of GPTModelProvider or MambaModelProvider." - ) - - DistillationProvider.__bases__ = (type(student_provider),) - student_provider.__class__ = DistillationProvider - - student_provider.teacher = teacher_provider - student_provider.kd_config = kd_config - student_provider.__post_init__() - - return student_provider diff --git a/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py deleted file mode 100644 index 0ab6083f77..0000000000 --- a/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Export utilities for Megatron-Bridge checkpoints.""" - -import shutil -from pathlib import Path - -from megatron.bridge import AutoBridge - -from modelopt.torch.utils import print_rank_0 - - -def export_to_hf_and_copy_config( - student_hf_path: str, - checkpoint_dir: str, - train_iters: int, - hf_export_path: str, - hf_model: str, - trust_remote_code: bool = False, -) -> None: - """ - Export Megatron checkpoint to HuggingFace format and copy config.json from student model. - - TODO: This script should not be needed (manually copying config.json from - student model to exported model). Remove it once export_to_hf() in AutoBridge - supports copying/preserving config.json from student model. - - - Args: - student_hf_path: Path to the original student HuggingFace model (source of config.json) - checkpoint_dir: Base directory where Megatron checkpoints are stored - train_iters: Number of training iterations (used to construct final checkpoint path) - hf_export_path: Directory path where the HuggingFace model will be saved - hf_model: HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct) - trust_remote_code: Whether to trust remote modeling code when loading the HF template model - """ - print_rank_0(f"\n{'=' * 80}") - print_rank_0("Exporting to HuggingFace format...") - print_rank_0(f"{'=' * 80}\n") - - # Construct path to final checkpoint iteration (format: iter_0000100 for 100 iterations) - final_iter_dir = Path(checkpoint_dir) / f"iter_{train_iters:07d}" - print_rank_0(f"📂 Using final checkpoint: {final_iter_dir}") - - # Use the final iteration directory for export (export_ckpt will validate it exists) - megatron_path = str(final_iter_dir) - - # Create bridge using standard model ID (not AnyModel checkpoint) to avoid sharding structure issues - print_rank_0("🌉 Creating bridge...") - print_rank_0(f" Using model ID: {hf_model}") - bridge = AutoBridge.from_hf_pretrained(hf_model, trust_remote_code=trust_remote_code) - - print_rank_0("📤 Exporting to HuggingFace format...") - bridge.export_ckpt( - megatron_path=megatron_path, hf_path=hf_export_path, show_progress=True, strict=True - ) - - print_rank_0(f"✅ Successfully exported model to: {hf_export_path}") - - # Copy config.json from student model to exported model (preserves block_configs) - student_config_path = Path(student_hf_path) / "config.json" - exported_config_path = Path(hf_export_path) / "config.json" - - print_rank_0(f"📋 Copying config.json from student model: {student_config_path}") - shutil.copy(student_config_path, exported_config_path) - print_rank_0(f"✅ Copied config.json to: {exported_config_path}") - - print_rank_0(f"\n{'=' * 80}") - print_rank_0("Export complete!") diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py index 6ca0ac0dd9..b556805d71 100644 --- a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -34,8 +34,8 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): Models are converted to include block_configs. """ # Prepare student and teacher models - student_hf_dir, student_anymodel_dir, _, teacher_anymodel_dir = ( - _prepare_student_and_teacher_models(project_root_path, tmp_path) + student_hf_dir, student_anymodel_dir, teacher_hf_dir, _ = _prepare_student_and_teacher_models( + project_root_path, tmp_path ) output_dir = tmp_path / "distill_output" @@ -59,7 +59,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): extend_cmd_parts( cmd_parts, student_hf_path=student_anymodel_dir, - teacher_hf_path=teacher_anymodel_dir, + teacher_hf_path=teacher_hf_dir, output_dir=output_dir, tp_size=tp_size, pp_size=1, @@ -75,7 +75,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path): eval_iters=0, log_interval=5, hf_export_path=hf_export_dir, - hf_model=student_hf_dir, + student_hf_model=student_hf_dir, ) run_example_command(cmd_parts, example_path="puzzletron/mbridge_distillation")