Commit 548e4dd

Author: Zhao Changmin

LLM: Adapt transformers models for optimize model SL (intel#9022)

* LLM: Adapt transformers model for SL

1 parent f64257a · commit 548e4dd

4 files changed: +286 -20 lines
python/llm/src/bigdl/llm/optimize.py

Lines changed: 74 additions & 19 deletions
@@ -24,6 +24,12 @@
 from accelerate.utils import set_module_tensor_to_device
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import extract_local_archive_file, get_local_shard_files
+import transformers
+from transformers import PreTrainedModel
+from .utils.common import MuteHFLogger
+from .utils.lazy_load_torch import LazyLoadTensors
+from contextlib import ExitStack, contextmanager


 # Simulate the Hugging Face format
@@ -37,7 +43,14 @@ def _save_low_bit(self, save_dir, *args, **kwargs):
                       f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
     os.makedirs(save_dir, exist_ok=True)
     model_path = os.path.join(save_dir, PYTORCH_MODEL_NAME)
-    torch.save(self.state_dict(), model_path, *args, **kwargs)
+    if isinstance(self, PreTrainedModel):
+        # We borrow this method to cover Transformers model cases
+        # as much as possible; later we may merge these two code paths.
+        self.save_pretrained(save_dir)
+    else:
+        # TODO: for low-bit models still larger than 8GB,
+        # save them into shards.
+        torch.save(self.state_dict(), model_path, *args, **kwargs)
     with open(os.path.join(save_dir, CONFIG_NAME), "w") as json_file:
         json.dump(self._bigdl_config, json_file)

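For orientation, a hedged sketch of what `save_low_bit` now leaves on disk after the hunk above (`saved_dir` is a placeholder; the shard naming follows Hugging Face's `save_pretrained` conventions, which this diff relies on but does not define):

# For a PreTrainedModel, save_pretrained writes the usual HF layout:
#   saved_dir/config.json
#   saved_dir/pytorch_model.bin              (or pytorch_model-0000X-of-0000N.bin shards
#   saved_dir/pytorch_model.bin.index.json    plus this index, when sharded)
# For a plain nn.Module, a single state dict is written instead:
#   saved_dir/<PYTORCH_MODEL_NAME>
# In both cases the BigDL config is dumped alongside:
#   saved_dir/<CONFIG_NAME>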
@@ -49,14 +62,44 @@ class DisableTorchAllocTensor():
     def __init__(self) -> None:
         self._old_torch_load_state_dict = Module.load_state_dict
         self._old_torch_to_device = Module.to
+        self._old_torch_load_from_state_dict = Module._load_from_state_dict
+        # ChatGLM2 initializes its weights manually,
+        # and `skip_init` initializes on `cpu` by default
+        self._old_skip_init = torch.nn.utils.skip_init

     def __enter__(self):
         Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], [])
+        Module._load_from_state_dict = lambda *args, **kwargs: None
         Module.to = lambda self, *args, **kwargs: self

+        def skip_init_on_meta(module_cls, *args, **kwargs):
+            kwargs['device'] = 'meta'
+            return self._old_skip_init(module_cls, *args, **kwargs)
+        torch.nn.utils.skip_init = skip_init_on_meta
+
     def __exit__(self, exc_type, exc_value, traceback):
         Module.load_state_dict = self._old_torch_load_state_dict
+        Module._load_from_state_dict = self._old_torch_load_from_state_dict
         Module.to = self._old_torch_to_device
+        torch.nn.utils.skip_init = self._old_skip_init
+
+
+class ContextManagers:
+    """
+    Wrapper for `contextlib.ExitStack` which enters a collection of context managers.
+    Adaptation of `ContextManagers` in the `fastcore` library.
+    """
+
+    def __init__(self, context_managers):
+        self.context_managers = context_managers
+        self.stack = ExitStack()
+
+    def __enter__(self):
+        for context_manager in self.context_managers:
+            self.stack.enter_context(context_manager)
+
+    def __exit__(self, *args, **kwargs):
+        self.stack.__exit__(*args, **kwargs)


 def low_bit_sanity_check(model_path):
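A toy sketch (not part of the commit) of the effect of the patched `DisableTorchAllocTensor` above: inside the context, `load_state_dict` and `Module.to` become no-ops and `skip_init` is redirected to the `meta` device, so constructing a layer allocates no real storage:

import torch
from bigdl.llm.optimize import DisableTorchAllocTensor

with DisableTorchAllocTensor():
    # skip_init is patched to construct on the meta device, so this
    # 4096x4096 weight occupies no host memory.
    layer = torch.nn.utils.skip_init(torch.nn.Linear, 4096, 4096)

print(layer.weight.device)  # meta, i.e. nothing was allocated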
@@ -76,31 +119,43 @@ def low_bit_sanity_check(model_path):
     return low_bit


-def load_low_bit(model_or_creator, model_path, **kwargs):
-    is_creator = not isinstance(model_or_creator, torch.nn.Module) \
-        and callable(model_or_creator)
-    low_bit = low_bit_sanity_check(model_path)
+@contextmanager
+def low_memory_init():
+    init_contexts = []
+    init_contexts.extend([init_empty_weights(), DisableTorchAllocTensor()])
+    # Load everything except the tensors' parameters
+    init_contexts.append(LazyLoadTensors())
+    # Since we have muted `torch.load`, HF will warn about missing keys,
+    # but this does not matter because we load the weights again later.
+    init_contexts.append(MuteHFLogger(logger=transformers.modeling_utils.logger))
+    with ContextManagers(init_contexts):
+        yield


+def load_low_bit(model, model_path):
+    low_bit = low_bit_sanity_check(model_path)
+    invalidInputError(isinstance(model, torch.nn.Module),
+                      "model should be an instance of "
+                      f"`torch.nn.Module`, but got {type(model)} instead.")
     if low_bit:
-        # a creator
-        if is_creator:
-            with init_empty_weights(), DisableTorchAllocTensor():
-                model = model_or_creator(**kwargs)
-        else:
-            model = model_or_creator
-            invalidInputError(isinstance(model, torch.nn.Module),
-                              "model_or_creator should be a instance of "
-                              "`torch.nn.Module`or a method that returns "
-                              f"an instance of `torch.nn.Module`, but got {type(model)} at last.")
         qtype = ggml_tensor_qtype[low_bit]
         model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)

-    state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
-    if is_creator:
+    resolved_archive_file, is_sharded = extract_local_archive_file(model_path, subfolder="")
+    if is_sharded:
+        # For now, only sharded transformers models
+        # can run in this branch.
+        resolved_archive_file, _ = \
+            get_local_shard_files(model_path,
+                                  resolved_archive_file,
+                                  subfolder="")
+    else:
+        resolved_archive_file = [os.path.join(model_path, PYTORCH_MODEL_NAME)]
+
+    for model_file in resolved_archive_file:
+        state_dict = torch.load(model_file)
         for param_name, param in state_dict.items():
             set_module_tensor_to_device(model, param_name, "cpu", param)
-    else:
-        model.load_state_dict(state_dict=state_dict)
     return model

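Taken together, the new API gives the two-step flow below (a minimal sketch mirroring the test added at the end of this commit; the checkpoint path is a placeholder):

from transformers import AutoModelForCausalLM
from bigdl.llm.optimize import optimize_model, load_low_bit, low_memory_init

# First run: optimize a Transformers model and save its low-bit weights.
model = AutoModelForCausalLM.from_pretrained("path/to/llama", torch_dtype="auto")
model = optimize_model(model)
model.save_low_bit("saved_dir")

# Later run: rebuild the model skeleton without allocating real tensors,
# then fill it from the saved low-bit checkpoint (sharded or not).
with low_memory_init():
    model = AutoModelForCausalLM.from_pretrained("saved_dir", torch_dtype="auto")
model = load_low_bit(model, model_path="saved_dir")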
python/llm/src/bigdl/llm/transformers/utils.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


-def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
+def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant=None):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isfile(
         os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
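For context on the new `variant=None` default: Hugging Face's `_add_variant` returns the weights name unchanged when no variant is given, so `load_low_bit` can call this helper without naming one. A sketch of the behavior, as implemented in transformers:

# _add_variant("pytorch_model.bin", None)   -> "pytorch_model.bin"
# _add_variant("pytorch_model.bin", "fp16") -> "pytorch_model.fp16.bin"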
python/llm/src/bigdl/llm/utils/lazy_load_torch.py

Lines changed: 193 additions & 0 deletions (new file)

@@ -0,0 +1,193 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===========================================================================
+#
+# This file is adapted from
+# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L516
+#
+# MIT License
+#
+# Copyright (c) 2023 Georgi Gerganov
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import torch
+from torch.serialization import StorageType
+import pickle
+import zipfile
+import io
+from typing import Dict, IO, Any, Callable
+from dataclasses import dataclass
+from .common import invalidInputError
+
+
+item_size = {torch.bfloat16: 2,
+             torch.float16: 2,
+             torch.int: 4,
+             torch.float: 4,
+             torch.float32: 4,
+             torch.int8: 1}
+
+
+@dataclass
+class LazyStorage:
+    load: Callable[[int, int], torch.Tensor]
+    kind: StorageType
+    description: str
+
+
+@dataclass
+class LazyTensor:
+    _load: Callable[[], torch.Tensor]
+    shape: list[int]
+    data_type: torch.dtype
+    description: str
+
+    def load(self) -> torch.Tensor:
+        ret = self._load()
+        return ret
+
+    def to(self, data_type):
+        def load() -> torch.Tensor:
+            return self.load().to(data_type)
+        return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
+
+
+def _load(pickle_fp, map_location, picklemodule, pickle_file='data.pkl', zip_file=None):
+
+    load_module_mapping: Dict[str, str] = {
+        'torch.tensor': 'torch._tensor'
+    }
+
+    class LazyUnpickler(picklemodule.Unpickler):
+        def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile):
+            super().__init__(fp)
+            self.data_base_path = data_base_path
+            self.zip_file = zip_file
+
+        def persistent_load(self, pid):
+            data_type = pid[1].dtype
+            filename_stem = pid[2]
+            filename = f'{self.data_base_path}/{filename_stem}'
+            info = self.zip_file.getinfo(filename)
+
+            def load(offset: int, elm_count: int):
+                dtype = data_type
+                fp = self.zip_file.open(info)
+                fp.seek(offset * item_size[dtype])
+                size = elm_count * item_size[dtype]
+                data = fp.read(size)
+                return torch.frombuffer(bytearray(data), dtype=dtype)
+            description = f'storage data_type={data_type} ' \
+                          f'path-in-zip={filename} path={self.zip_file.filename}'
+            return LazyStorage(load=load, kind=pid[1], description=description)
+
+        @staticmethod
+        def lazy_rebuild_tensor_v2(storage: Any,
+                                   storage_offset: Any,
+                                   size: Any,
+                                   stride: Any,
+                                   requires_grad: Any,
+                                   backward_hooks: Any,
+                                   metadata: Any = None) -> LazyTensor:
+            invalidInputError(isinstance(storage, LazyStorage),
+                              "storage should be an instance of class `LazyStorage`, "
+                              f"but got {type(storage)}.")
+
+            def load() -> torch.Tensor:
+                elm_count = stride[0] * size[0]
+                return storage.load(storage_offset, elm_count).reshape(size)
+            description = f'pickled storage_offset={storage_offset} in {storage.description}'
+            return LazyTensor(load, list(size), storage.kind.dtype, description)
+
+        @staticmethod
+        def rebuild_from_type_v2(func, new_type, args, state):
+            return func(*args)
+
+        CLASSES: dict[tuple[str, str], Any] = {
+            ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
+            ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
+            ('torch', 'Tensor'): LazyTensor,
+        }
+
+        def find_class(self, mod_name, name):
+            if (mod_name, name) in self.CLASSES:
+                return self.CLASSES[(mod_name, name)]
+            if type(name) is str and 'Storage' in name:
+                try:
+                    return StorageType(name)
+                except KeyError:
+                    pass
+            mod_name = load_module_mapping.get(mod_name, mod_name)
+            return super().find_class(mod_name, name)
+
+    unpickler = LazyUnpickler(pickle_fp,
+                              data_base_path=pickle_file,
+                              zip_file=zip_file)
+    result = unpickler.load()
+
+    return result
+
+
+# This can only be used on Hugging Face transformers checkpoints saved in the zip format.
+def lazyload(
+    f,
+    *args,
+    **kwargs
+):
+    if isinstance(f, io.BufferedIOBase):
+        fp = f
+    else:
+        fp = open(f, 'rb')
+    zf = zipfile.ZipFile(fp)
+    pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')]
+    invalidInputError(len(pickle_paths) == 1,
+                      "Exactly one pickle file is expected in the archive, "
+                      f"but got {pickle_paths}.")
+    pickle_fp = zf.open(pickle_paths[0], 'r')
+    state_dict = _load(pickle_fp, None, pickle, pickle_file=pickle_paths[0][:-4], zip_file=zf)
+    return state_dict
+
+
+class LazyLoadTensors:
+    def __init__(self):
+        self.torch_load = torch.load
+
+    def __enter__(self):
+        torch.load = lazyload
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        torch.load = self.torch_load
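A hedged usage sketch of the lazy loader above (the checkpoint path is a placeholder; as the comment above `lazyload` notes, this only applies to zip-format checkpoints): inside the context, `torch.load` returns `LazyTensor` handles, and bytes are only read from the archive when `.load()` is called:

import torch
from bigdl.llm.utils.lazy_load_torch import LazyLoadTensors

with LazyLoadTensors():
    # Only the pickle metadata is parsed here; no tensor data is read.
    state_dict = torch.load("pytorch_model.bin")

# Each entry is a LazyTensor; materialize one on demand.
name, lazy_tensor = next(iter(state_dict.items()))
tensor = lazy_tensor.load()
print(name, tensor.shape, tensor.dtype)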

python/llm/test/convert/test_convert_model.py

Lines changed: 18 additions & 0 deletions
@@ -22,6 +22,7 @@

 from bigdl.llm import llm_convert
 from bigdl.llm.transformers import AutoModelForCausalLM
+from bigdl.llm.optimize import optimize_model, load_low_bit, low_memory_init


 llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')

@@ -87,5 +88,22 @@ def test_transformer_convert_llama_save_load(self):
         newModel = AutoModelForCausalLM.load_low_bit(tempdir)
         assert newModel is not None

+    def test_optimize_transformers_llama(self):
+        from transformers import AutoModelForCausalLM as AutoCLM
+        with tempfile.TemporaryDirectory(dir=output_dir) as tempdir:
+            model = AutoCLM.from_pretrained(llama_model_path,
+                                            torch_dtype="auto",
+                                            low_cpu_mem_usage=True,
+                                            trust_remote_code=True)
+            model = optimize_model(model)
+            model.save_low_bit(tempdir)
+            with low_memory_init():
+                new_model = AutoCLM.from_pretrained(tempdir,
+                                                    torch_dtype="auto",
+                                                    trust_remote_code=True)
+            new_model = load_low_bit(new_model,
+                                     model_path=tempdir)
+            assert new_model is not None
+
 if __name__ == '__main__':
     pytest.main([__file__])
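Note that running this test relies on the module's existing setup: `llama_model_path` comes from the `LLAMA_ORIGIN_PATH` environment variable and `output_dir` from the surrounding test configuration, so both must point at valid local paths before invoking pytest on this file.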
