Skip to content

Commit 0a3b7f2

Browse files
Add scoring module
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent d4dd0d7 commit 0a3b7f2

File tree

4 files changed

+538
-1
lines changed

4 files changed

+538
-1
lines changed

modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
from pathlib import Path
2222

2323
import mip_and_realize_models
24-
import scoring
2524
import torch
2625
from torch import nn
2726

2827
import modelopt.torch._compress.build_library_and_stats as build_library_and_stats
2928
import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts
29+
import modelopt.torch._compress.scoring.scoring as scoring
3030
from modelopt.torch._compress.activation_scoring import score_pruning_activations
3131
from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (
3232
convert_llama3_to_decilm,
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# mypy: ignore-errors
16+
import os
17+
import re
18+
from glob import glob
19+
from pathlib import Path
20+
21+
import hydra
22+
import numpy as np
23+
import pandas as pd
24+
import torch
25+
from omegaconf import DictConfig
26+
27+
from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
28+
from modelopt.torch._compress.tools.logger import mprint
29+
from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime
30+
from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import (
31+
validate_puzzle_solutions,
32+
)
33+
from modelopt.torch._compress.utils.dist_utils import is_distributed
34+
35+
36+
def extract_solution_id(filename):
37+
pattern = r"solution_(\d+)\.json"
38+
match = re.search(pattern, filename)
39+
40+
if match:
41+
solution_id = match.group(1)
42+
return int(solution_id)
43+
else:
44+
mprint(f"Couldn't extract solutions_id from file {filename}")
45+
46+
47+
def find_missing_solutions(solutions_df, validation_dir):
48+
all_solutions = np.arange(solutions_df.shape[0])
49+
50+
benchmarked_solutions = list(glob(f"{validation_dir}/solution*.json"))
51+
benchmarked_solutions = [
52+
extract_solution_id(os.path.basename(s)) for s in benchmarked_solutions
53+
]
54+
benchmarked_solutions = [s for s in benchmarked_solutions if s is not None]
55+
56+
unbenchmarked_solutions = np.setdiff1d(all_solutions, benchmarked_solutions)
57+
return unbenchmarked_solutions.tolist()
58+
59+
60+
def get_solutions_to_validate(cfg: DictConfig):
61+
_solutions_to_validate = cfg.scoring.solutions_to_validate
62+
if _solutions_to_validate is None:
63+
single_block_replacement_solutions = pd.read_json(cfg.scoring.solutions_path)
64+
if cfg.scoring.skip_existing_solutions:
65+
_solutions_to_validate = find_missing_solutions(
66+
single_block_replacement_solutions, cfg.scoring.output_dir
67+
)
68+
else:
69+
_solutions_to_validate = np.arange(single_block_replacement_solutions.shape[0]).tolist()
70+
return _solutions_to_validate
71+
72+
73+
def launch_scoring(cfg: DictConfig, runtime: IRuntime):
74+
cfg.scoring.solutions_to_validate = get_solutions_to_validate(cfg)
75+
mprint(f"Solutions to validate: {cfg.scoring.solutions_to_validate}")
76+
validate_puzzle_solutions(args=cfg.scoring, runtime=runtime)
77+
78+
79+
@hydra.main("", version_base="1.3")
80+
def main(cfg: DictConfig) -> None:
81+
cfg = hydra.utils.instantiate(cfg)
82+
mprint(cfg)
83+
84+
_runtime = (
85+
NativeDdpRuntime(
86+
dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes")
87+
)
88+
if is_distributed()
89+
else BaseRuntime(dtype=torch.bfloat16)
90+
)
91+
with _runtime as runtime:
92+
launch_scoring(cfg, runtime)
93+
94+
95+
if __name__ == "__main__":
96+
register_hydra_resolvers()
97+
main()

0 commit comments

Comments
 (0)