Commit 6d84a30

feat: overriding quant types for specific tensors on model conversion (#724)

1 parent dafc32d commit 6d84a30

File tree

4 files changed (+71, -14 lines)

examples/cli/main.cpp

Lines changed: 13 additions & 1 deletion

@@ -87,6 +87,7 @@ struct SDParams {
     std::string stacked_id_embeddings_path;
     std::string input_id_images_path;
     sd_type_t wtype = SD_TYPE_COUNT;
+    std::string tensor_type_rules;
     std::string lora_model_dir;
     std::string output_path = "output.png";
     std::string input_path;
@@ -223,6 +224,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --upscale-repeats                  Run the ESRGAN upscaler this many times (default 1)\n");
     printf("  --type [TYPE]                      weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
     printf("                                     If not specified, the default is the type of the weight file\n");
+    printf("  --tensor-type-rules [EXPRESSION]   weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
     printf("  --lora-model-dir [DIR]             lora model directory\n");
     printf("  -i, --init-img [IMAGE]             path to the input image, required by img2img\n");
     printf("  --mask [MASK]                      path to the mask image, required by img2img with mask\n");
@@ -404,6 +406,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                         valid_types.c_str());
                 exit(1);
             }
+        } else if (arg == "--tensor-type-rules") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.tensor_type_rules = argv[i];
         } else if (arg == "--lora-model-dir") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -733,6 +741,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         exit(1);
     }
 
+    if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
+        fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
+    }
+
     if (params.seed < 0) {
         srand((int)time(NULL));
         params.seed = rand();
@@ -845,7 +857,7 @@ int main(int argc, const char* argv[]) {
     }
 
     if (params.mode == CONVERT) {
-        bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
+        bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str());
        if (!success) {
             fprintf(stderr,
                     "convert '%s'/'%s' to '%s' failed\n",

model.cpp

Lines changed: 56 additions & 11 deletions

@@ -100,7 +100,7 @@ const char* unused_tensors[] = {
     "model_ema.diffusion_model",
     "embedding_manager",
     "denoiser.sigmas",
-    "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
+    "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight",  // only used during training
 };
 
 bool is_unused_tensor(std::string name) {
@@ -1169,7 +1169,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             n_dims = 1;
         }
 
-
         TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
         tensor_storage.reverse_ne();
 
@@ -1914,7 +1913,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
         };
         int tensor_count = 0;
         int64_t t1       = ggml_time_ms();
-        bool partial = false;
+        bool partial     = false;
        for (auto& tensor_storage : processed_tensor_storages) {
             if (tensor_storage.file_index != file_index) {
                 ++tensor_count;
@@ -1997,9 +1996,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 }
             }
             size_t tensor_max = processed_tensor_storages.size();
-            int64_t t2 = ggml_time_ms();
+            int64_t t2        = ggml_time_ms();
             pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f);
-            t1 = t2;
+            t1      = t2;
             partial = tensor_count != tensor_max;
         }
 
@@ -2088,6 +2087,41 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
     return true;
 }
 
+std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : splitString(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name      = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {
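For the help-text example, `parse_tensor_type_rules("^vae\\.=f16,model\\.=q8_0")` returns the pairs `("^vae\\.", GGML_TYPE_F16)` and `("model\\.", GGML_TYPE_Q8_0)`. A type name is accepted when a ggml type trait exists with a matching `type_name`, a non-zero `type_size`, and a `to_float` handler; `f32` needs its special case, presumably because ggml's f32 traits carry no `to_float` conversion function. A minimal sketch, assuming only `ggml.h`, that lists the spellings this lookup would accept:

```cpp
#include "ggml.h"

#include <cstdio>

int main() {
    // Enumerate the ggml types whose traits pass the same checks as the
    // parser above; their type_name fields are the accepted rule spellings.
    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
        const auto* trait = ggml_get_type_traits((ggml_type)i);
        if (trait->to_float && trait->type_size) {
            printf("%s\n", trait->type_name);
        }
    }
    printf("f32\n");  // special-cased in parse_tensor_type_rules
    return 0;
}
```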
@@ -2119,7 +2153,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
     return false;
 }
 
-bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
+bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
     auto backend = ggml_backend_cpu_init();
     size_t mem_size = 1 * 1024 * 1024;  // for padding
     mem_size += tensor_storages.size() * ggml_tensor_overhead();
@@ -2129,12 +2163,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
 
     gguf_context* gguf_ctx = gguf_init_empty();
 
+    auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
+
     auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
         const std::string& name = tensor_storage.name;
+        ggml_type tensor_type   = tensor_storage.type;
+        ggml_type dst_type      = type;
 
-        ggml_type tensor_type = tensor_storage.type;
-        if (tensor_should_be_converted(tensor_storage, type)) {
-            tensor_type = type;
+        for (const auto& tensor_type_rule : tensor_type_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+
+        if (tensor_should_be_converted(tensor_storage, dst_type)) {
+            tensor_type = dst_type;
         }
 
         ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
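Rules are applied in the order given and the first match wins. Because `std::regex_search` matches anywhere in the tensor name, an unanchored pattern such as `model\.` also hits names like `cond_stage_model.` or `first_stage_model.`, which is why the help-text example anchors the VAE rule with `^`. A tensor matching no rule keeps the global target type, and `tensor_should_be_converted` still has the final say. Also note the diff constructs a `std::regex` per rule per tensor; the patterns could be compiled once up front, though conversion time is dominated by quantization anyway. A standalone sketch of this resolution order (the tensor names are illustrative, not from a real checkpoint):

```cpp
#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

int main() {
    // Ordered override rules, shaped like parse_tensor_type_rules output
    // (type names kept as strings here for simplicity).
    std::vector<std::pair<std::string, std::string>> rules = {
        {"^vae\\.", "f16"},    // keep VAE weights in higher precision
        {"model\\.", "q8_0"},  // quantize diffusion-model weights
    };
    const char* names[] = {
        "vae.decoder.conv_in.weight",            // "^vae\." wins -> f16
        "model.diffusion_model.out.2.weight",    // "model\." wins -> q8_0
        "cond_stage_model.transformer.mlp.fc1",  // unanchored "model\." also matches -> q8_0
        "alphas_cumprod",                        // no rule -> global type
    };
    for (const char* raw : names) {
        std::string name(raw);
        std::string type = "q4_0";  // stand-in for the global --type target
        for (const auto& rule : rules) {
            if (std::regex_search(name, std::regex(rule.first))) {
                type = rule.second;
                break;  // first matching rule wins, as in on_new_tensor_cb
            }
        }
        printf("%-40s -> %s\n", raw, type.c_str());
    }
    return 0;
}
```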
@@ -2193,7 +2238,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
     return mem_size;
 }
 
-bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
+bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) {
     ModelLoader model_loader;
 
     if (!model_loader.init_from_file(input_path)) {
@@ -2207,6 +2252,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
             return false;
         }
     }
-    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
+    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
     return success;
 }

model.h

Lines changed: 1 addition & 1 deletion

@@ -222,7 +222,7 @@ class ModelLoader {
                       ggml_backend_t backend,
                       std::set<std::string> ignore_tensors = {});
 
-    bool save_to_gguf_file(const std::string& file_path, ggml_type type);
+    bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
     bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
     int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
     ~ModelLoader() = default;

stable-diffusion.h

Lines changed: 1 addition & 1 deletion

@@ -257,7 +257,7 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
 
 SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
 
-SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
+SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type, const char* tensor_type_rules);
 
 SD_API uint8_t* preprocess_canny(uint8_t* img,
                                  int width,
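With the new parameter, a host application can drive a rule-aware conversion in one call, just as the updated main.cpp does with `params.tensor_type_rules.c_str()` (an empty string preserves the old single-type behavior). A minimal sketch; the file paths are hypothetical, and `SD_TYPE_Q8_0` is assumed to be among the `sd_type_t` values, which mirror the ggml types:

```cpp
#include "stable-diffusion.h"

#include <cstdio>

int main() {
    // Convert a checkpoint to GGUF at q8_0 overall, but keep VAE tensors
    // at f16 via a per-tensor override rule.
    bool ok = convert("sd-v1-5.safetensors",  // input model (hypothetical path)
                      "",                     // no external VAE
                      "sd-v1-5-q8_0.gguf",    // output file (hypothetical path)
                      SD_TYPE_Q8_0,           // global target type
                      "^vae\\.=f16");         // per-tensor override rules
    if (!ok) {
        fprintf(stderr, "conversion failed\n");
        return 1;
    }
    return 0;
}
```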
