Skip to content

Commit 4899bb5

Browse files
add mip module
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 0a3b7f2 commit 4899bb5

File tree

3 files changed

+919
-1
lines changed

3 files changed

+919
-1
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# mypy: ignore-errors
16+
from pathlib import Path
17+
from typing import List
18+
19+
import hydra
20+
import torch
21+
import torch.distributed as dist
22+
from omegaconf import DictConfig
23+
from utils.dist_utils import is_distributed
24+
25+
from modelopt.torch._compress.mip.run_puzzle import run_puzzle
26+
from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
27+
from modelopt.torch._compress.tools.logger import mprint
28+
from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime
29+
from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import (
30+
validate_puzzle_solutions,
31+
)
32+
33+
34+
def launch_mip(cfg: DictConfig) -> List[str]:
    """Run the puzzle search on the ``cfg.mip`` sub-config.

    Args:
        cfg: Full run configuration; only ``cfg.mip`` is forwarded to
            ``run_puzzle``.

    Returns:
        Whatever ``run_puzzle`` returns. It is annotated as a list of strings,
        and callers in this file treat each entry as a solution path.
    """
    return run_puzzle(args=cfg.mip)
37+
38+
39+
def launch_realize_model(cfg: DictConfig, runtime: IRuntime):
    """Validate puzzle solutions described by the ``cfg.realize_model`` sub-config.

    Args:
        cfg: Full run configuration; only ``cfg.realize_model`` is forwarded.
        runtime: Runtime object handed through to ``validate_puzzle_solutions``.
    """
    realize_args = cfg.realize_model
    validate_puzzle_solutions(args=realize_args, runtime=runtime)
41+
42+
43+
def launch_mip_and_realize_model(cfg: DictConfig, runtime: IRuntime):
    """Run the MIP step on the main process, then realize each solution everywhere.

    The main process calls ``launch_mip``; the other processes start with no
    solution paths. Unless ``cfg.skip_realize_model`` is set, the number of
    solutions and then the solution paths themselves are broadcast from rank 0
    (only when ``runtime.world_size > 1``), and every process runs
    ``launch_realize_model`` once per solution.

    Args:
        cfg: Run configuration. Reads ``cfg.skip_realize_model`` and overwrites
            ``cfg.realize_model.solutions_path`` for each solution.
        runtime: Runtime providing ``is_main_process``, ``world_size``,
            ``global_rank`` and ``wait_for_everyone()``.
    """
    on_main = runtime.is_main_process
    solution_paths = launch_mip(cfg) if on_main else None
    length_tensor = torch.tensor(
        [len(solution_paths) if on_main else 0], dtype=torch.long
    )

    if cfg.skip_realize_model:
        return

    multi_process = runtime.world_size > 1

    # Receivers must know the list length before the object broadcast so they
    # can allocate a placeholder list of matching size.
    if multi_process:
        dist.broadcast(length_tensor, src=0)
    num_solutions = length_tensor.item()

    if runtime.global_rank != 0:
        solution_paths = [None] * num_solutions

    if multi_process:
        dist.broadcast_object_list(solution_paths, src=0)

    # NOTE(review): nesting was reconstructed from a flattened listing; the
    # per-solution barrier placement should be confirmed against the repository.
    for solution_path in solution_paths:
        mprint(f"Realize model for the solution: {solution_path}")
        cfg.realize_model.solutions_path = Path(solution_path)
        launch_realize_model(cfg, runtime=runtime)
        runtime.wait_for_everyone()
68+
69+
70+
@hydra.main("", version_base="1.3")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: build a runtime and run the MIP + realize pipeline.

    The configuration is passed through ``hydra.utils.instantiate`` first.
    A ``NativeDdpRuntime`` is used when ``is_distributed()`` is true (with its
    timeout taken from ``cfg.nccl_timeout_minutes``); otherwise a
    ``BaseRuntime`` is used. Both are created with ``torch.bfloat16``.

    Args:
        cfg: Configuration supplied by Hydra.
    """
    cfg = hydra.utils.instantiate(cfg)

    if is_distributed():
        # Use the class name this module actually imports; the previous
        # spelling (``NativeDDP_Runtime``) was undefined and raised NameError.
        _runtime = NativeDdpRuntime(
            dtype=torch.bfloat16,
            torch_distributed_timeout=cfg.nccl_timeout_minutes,
        )
    else:
        _runtime = BaseRuntime(dtype=torch.bfloat16)

    # The runtime is a context manager; the pipeline runs inside its scope.
    with _runtime as runtime:
        launch_mip_and_realize_model(cfg, runtime)
83+
84+
85+
if __name__ == "__main__":
    # Resolvers are registered before Hydra builds the config in main().
    register_hydra_resolvers()
    main()

0 commit comments

Comments
 (0)