Skip to content

Commit d4dd0d7

Browse files
import refactoring
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 22da6b0 commit d4dd0d7

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

modelopt/torch/_compress/subblock_stats/calc_subblock_memory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919

2020
import numpy as np
2121
import torch
22-
from puzzle_tools.deci_lm_hf_code.block_config import AttentionConfig, FFNConfig, MambaConfig
2322

23+
from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import (
24+
AttentionConfig,
25+
FFNConfig,
26+
MambaConfig,
27+
)
2428
from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig
2529
from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMMoe
2630
from modelopt.torch._compress.utils.utils import (

modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,15 @@
3434
import torch
3535
from frozendict import frozendict
3636
from omegaconf import DictConfig, ListConfig, OmegaConf
37-
from puzzle_tools.deci_lm_hf_code.block_config import (
37+
from tqdm import tqdm
38+
39+
from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import (
3840
AttentionConfig,
3941
BlockConfig,
4042
FFNConfig,
4143
SubblockConfig,
4244
)
43-
from puzzle_tools.replacement_utils import parse_layer_replacement
44-
from tqdm import tqdm
45-
from utils.parsing import format_global_config
46-
45+
from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement
4746
from modelopt.torch._compress.subblock_stats.calc_subblock_memory import (
4847
calc_subblock_active_params,
4948
calculate_non_block_memory,
@@ -55,6 +54,7 @@
5554
from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
5655
from modelopt.torch._compress.tools.logger import mprint
5756
from modelopt.torch._compress.tools.robust_json import json_dump
57+
from modelopt.torch._compress.utils.parsing import format_global_config
5858

5959
# Type variable for dataclasses
6060
T_DataClass = TypeVar("T_DataClass")

0 commit comments

Comments
 (0)