## convert_hf_to_gguf.py dequantization

Some models are already quantized, and to convert them we need to dequantize them in
llama.cpp before converting to gguf format.

So ModelBase in convert_hf_to_gguf.py has a number of members, one of which is:
```python
class ModelBase:
    ...

    model_tensors: dict[str, Callable[[], Tensor]]
```
This is a dictionary of tensor names to functions that return the tensor.
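
A minimal sketch of the idea (the tensor name here is made up): nothing is loaded until the stored function is actually called.
```python
import torch
from typing import Callable

# Each entry defers the actual load until the callable is invoked.
tensors: dict[str, Callable[[], torch.Tensor]] = {
    "blk.0.attn_q.weight": lambda: torch.zeros(4096, 4096),
}

# Only now is the tensor materialized:
t = tensors["blk.0.attn_q.weight"]()
print(t.shape)  # torch.Size([4096, 4096])
```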

In the constructor we then have:
```python
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
```
So this allows passing in a remote Hugging Face model id, or None, in which case
a local model (on the local disk) is used:
```python
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
    tensors: dict[str, Callable[[], Tensor]] = {}
```
And we can see the return type is the dictionary of tensor names to functions that return tensors.

If remote tensors are used we have the following code:
```python
if remote_hf_model_id is not None:
    is_safetensors = True

    logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
    remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
    for name, remote_tensor in remote_tensors.items():
        tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)

    return tensors
```
So that is how remote models are handled (notice the early return statement).

For local models we have:
```python
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
is_safetensors: bool = len(part_names) > 0
if not is_safetensors:
    part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
```
part_names is just the list of weight files: all the .safetensors files in the model directory, with a
fallback to .bin files.
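
get_model_part_names itself isn't shown here, but conceptually it just scans the model directory for matching files. A rough sketch of the idea (not the actual implementation):
```python
from pathlib import Path

# Rough sketch only: collect shard files such as
# model-00001-of-00003.safetensors by prefix and suffix.
def get_part_names_sketch(dir_model: Path, prefix: str, suffix: str) -> list[str]:
    return sorted(p.name for p in dir_model.iterdir()
                  if p.name.startswith(prefix) and p.name.endswith(suffix))
```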

```python
tensor_names_from_index: set[str] = set()

index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
index_name += ".index.json"
index_file = self.dir_model / index_name
```
That last line joins the paths in a platform-independent way using pathlib's `/` operator.
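
A quick standalone example of that operator (the path here is made up):
```python
from pathlib import Path

# Path overloads `/` so components are joined with the correct
# separator for whatever platform the script runs on.
index_file = Path("models") / "my-model" / "model.safetensors.index.json"
print(index_file)  # models/my-model/model.safetensors.index.json (on POSIX)
```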

If the index file exists, it will be opened and the weight map loaded:
```python
if index_file.is_file():
    logger.info(f"gguf: loading model weight map from '{index_name}'")
    with open(index_file, "r", encoding="utf-8") as f:
        index: dict[str, Any] = json.load(f)
        weight_map = index.get("weight_map")
        if weight_map is None or not isinstance(weight_map, dict):
            raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
        tensor_names_from_index.update(weight_map.keys())
else:
    weight_map = {}
```
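
For reference, the weight map inside such an index file simply maps each tensor name to the shard file that contains it, along these lines (names and sizes illustrative):
```console
{
  "metadata": { "total_size": ... },
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    ...
  }
}
```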

Next we will iterate over all the weight files:
```python
for part_name in part_names:
    logger.info(f"gguf: indexing model part '{part_name}'")
    ctx: ContextManager[Any]
    if is_safetensors:
        from safetensors import safe_open
        ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
    else:
        ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
```
ContextManager is a Python type that allows us to use the "with" statement to manage resources. safe_open
returns a ContextManager that opens the safetensors file and allows us to iterate over the tensors in it without
loading them all into memory at once.
torch.load does not return a ContextManager, but it is wrapped in a dummy context manager using contextlib.nullcontext
so that it can be treated as one.
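
nullcontext in isolation: entering the `with` block yields the wrapped object unchanged, and exiting does nothing.
```python
import contextlib

# Wrap a plain object so it can be used in a `with` statement.
state_dict = {"w": 1}
with contextlib.nullcontext(state_dict) as model_part:
    assert model_part is state_dict
```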

```python
with ctx as model_part:
    assert model_part is not None

    # iterate over all the tensors in this part (weight file)
    for name in model_part.keys():
        if is_safetensors:
            if self.lazy:
                data = model_part.get_slice(name)
                data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
            else:
                data = model_part.get_tensor(name)
                data_gen = lambda data=data: data  # noqa: E731
        else:
            data = model_part[name]
            if self.lazy:
                data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
            else:
                data_gen = lambda data=data: data  # noqa: E731
        tensors[name] = data_gen
```
`lambda data=data:` uses a default argument to capture the current value of data. Without this, all
lambdas in the loop would close over the same data variable and end up returning the last one.
And notice that tensors[name] is set to data_gen, which is a function that returns the tensor data.
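
The pitfall in isolation:
```python
# All three closures share the loop variable, which ends up at its final value.
fns = [lambda: i for i in range(3)]
print([f() for f in fns])  # [2, 2, 2]

# A default argument is evaluated at definition time, freezing each value.
fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])  # [0, 1, 2]
```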

And finally we have the consistency check:
```python
if len(tensor_names_from_index) > 0:
    tensor_names_from_parts = set(tensors.keys())
    if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
        missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
        extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
        missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
        if len(extra) == 0 and len(missing_files) > 0:
            raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
                             f"Missing tensors: {missing}")
        else:
            raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
                             f"Missing tensors: {missing}\n"
                             f"Extra tensors: {extra}")

return tensors
```
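
A small illustration of the set operations used in the check (names made up):
```python
index_names = {"a.weight", "b.weight", "c.weight"}
part_names = {"b.weight", "c.weight", "d.weight"}

# symmetric_difference: names that appear on one side but not both.
print(index_names.symmetric_difference(part_names))  # {'a.weight', 'd.weight'}
# difference in each direction gives "missing" and "extra" respectively.
print(index_names.difference(part_names))            # {'a.weight'}
print(part_names.difference(index_names))            # {'d.weight'}
```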

Back in the constructor we later have:
```python
self.dequant_model()
```
And in dequant_model we first check if the model configuration contains a quantization_config:
```python
def dequant_model(self):
    tensors_to_remove: list[str] = []
    new_tensors: dict[str, Callable[[], Tensor]] = {}

    if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
        quant_method = quant_config.get("quant_method")
```

For example, a model's configuration might contain something like this:
```console
"quantization_config": {
    "activation_scheme": "dynamic",
    "modules_to_not_convert": null,
    "quant_method": "fp8",
    "weight_block_size": [
        128,
        128
    ]
},
```

Then we have a few functions defined for dequantizing different types of quantization methods:
```python
def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
    ...

def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
    ...

def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
    scale = scale.float()

    if (weight_block_size := quant_config.get("weight_block_size")):
        # TODO: make sure it's a list of integers
        for i, size in enumerate(weight_block_size):
            scale = scale.repeat_interleave(size, i)
        # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
        scale = scale[tuple(slice(0, size) for size in weight.shape)]

    return weight.float() * scale
```

I'm showing the simple one as that is the example I'm working with at the moment:
```python
if quant_method == "bitnet":
    for name in self.model_tensors.keys():
        if name.endswith(".weight_scale"):
            weight_name = name.removesuffix("_scale")
            w = self.model_tensors[weight_name]
            s = self.model_tensors[name]
            self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
            tensors_to_remove.append(name)
elif quant_method == "fp8":
    for name in self.model_tensors.keys():
        if name.endswith(".weight_scale_inv"):
            weight_name = name.removesuffix("_scale_inv")
            w = self.model_tensors[weight_name]
            s = self.model_tensors[name]
            self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
            tensors_to_remove.append(name)
```

So a model that has quantized weights will also have scale tensors that are used to dequantize the weights.
And these are the inverse scales/deltas, so to dequantize we just multiply the weights by the scale.
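
A rough numeric sketch of what dequant_simple does for block-quantized weights. The shapes here are made up, and the weights stand in for fp8 values already upcast to float:
```python
import torch

# One scale value per 128x128 block of the weight tensor.
weight = torch.randn(256, 384)        # stands in for the raw fp8 weights
scale = torch.rand(2, 3)              # (256/128, 384/128) block scales

s = scale.float()
for i, size in enumerate([128, 128]):
    s = s.repeat_interleave(size, i)  # expand each block scale to full size
s = s[tuple(slice(0, n) for n in weight.shape)]  # unpad if needed

dequant = weight.float() * s
print(dequant.shape)  # torch.Size([256, 384])
```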
