Skip to content

Commit 5650e56

Browse files
committed
Squash instruct-Pix2Pix(#leejet#679) support
Support 2 conditionings cfg; do not re-encode the exact same image twice; fixes for 2-cfg; fix pix2pix latent inputs + improve inpainting a bit + fix naming; prepare for other pix2pix-like models; support sdxl ip2p; fix reference image embeddings; support 2-cond cfg properly in cli; fix typo in help; support masks for ip2p models.
1 parent 6d84a30 commit 5650e56

File tree

6 files changed

+268
-201
lines changed

6 files changed

+268
-201
lines changed

examples/cli/main.cpp

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,16 @@ struct SDParams {
9797

9898
std::string prompt;
9999
std::string negative_prompt;
100-
float min_cfg = 1.0f;
101-
float cfg_scale = 7.0f;
102-
float guidance = 3.5f;
103-
float eta = 0.f;
104-
float style_ratio = 20.f;
105-
int clip_skip = -1; // <= 0 represents unspecified
106-
int width = 512;
107-
int height = 512;
108-
int batch_count = 1;
100+
float min_cfg = 1.0f;
101+
float cfg_scale = 7.0f;
102+
float img_cfg_scale = INFINITY;
103+
float guidance = 3.5f;
104+
float eta = 0.f;
105+
float style_ratio = 20.f;
106+
int clip_skip = -1; // <= 0 represents unspecified
107+
int width = 512;
108+
int height = 512;
109+
int batch_count = 1;
109110

110111
int video_frames = 6;
111112
int motion_bucket_id = 127;
@@ -176,6 +177,7 @@ void print_params(SDParams params) {
176177
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
177178
printf(" min_cfg: %.2f\n", params.min_cfg);
178179
printf(" cfg_scale: %.2f\n", params.cfg_scale);
180+
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
179181
printf(" slg_scale: %.2f\n", params.slg_scale);
180182
printf(" guidance: %.2f\n", params.guidance);
181183
printf(" eta: %.2f\n", params.eta);
@@ -234,7 +236,8 @@ void print_usage(int argc, const char* argv[]) {
234236
printf(" -p, --prompt [PROMPT] the prompt to render\n");
235237
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
236238
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
237-
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
239+
printf(" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n");
240+
printf(" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n");
238241
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
239242
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
240243
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
@@ -470,6 +473,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
470473
break;
471474
}
472475
params.cfg_scale = std::stof(argv[i]);
476+
} else if (arg == "--img-cfg-scale") {
477+
if (++i >= argc) {
478+
invalid_arg = true;
479+
break;
480+
}
481+
params.img_cfg_scale = std::stof(argv[i]);
473482
} else if (arg == "--guidance") {
474483
if (++i >= argc) {
475484
invalid_arg = true;
@@ -755,6 +764,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
755764
params.output_path = "output.gguf";
756765
}
757766
}
767+
768+
if (!isfinite(params.img_cfg_scale)) {
769+
params.img_cfg_scale = params.cfg_scale;
770+
}
758771
}
759772

760773
static std::string sd_basename(const std::string& path) {
@@ -849,6 +862,18 @@ int main(int argc, const char* argv[]) {
849862

850863
parse_args(argc, argv, params);
851864

865+
sd_guidance_params_t guidance_params = {params.cfg_scale,
866+
params.img_cfg_scale,
867+
params.min_cfg,
868+
params.guidance,
869+
{
870+
params.skip_layers.data(),
871+
params.skip_layers.size(),
872+
params.skip_layer_start,
873+
params.skip_layer_end,
874+
params.slg_scale,
875+
}};
876+
852877
sd_set_log_callback(sd_log_cb, (void*)&params);
853878

854879
if (params.verbose) {
@@ -1041,8 +1066,7 @@ int main(int argc, const char* argv[]) {
10411066
params.prompt.c_str(),
10421067
params.negative_prompt.c_str(),
10431068
params.clip_skip,
1044-
params.cfg_scale,
1045-
params.guidance,
1069+
guidance_params,
10461070
params.eta,
10471071
params.width,
10481072
params.height,
@@ -1054,12 +1078,7 @@ int main(int argc, const char* argv[]) {
10541078
params.control_strength,
10551079
params.style_ratio,
10561080
params.normalize_input,
1057-
params.input_id_images_path.c_str(),
1058-
params.skip_layers.data(),
1059-
params.skip_layers.size(),
1060-
params.slg_scale,
1061-
params.skip_layer_start,
1062-
params.skip_layer_end);
1081+
params.input_id_images_path.c_str());
10631082
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
10641083
sd_image_t input_image = {(uint32_t)params.width,
10651084
(uint32_t)params.height,
@@ -1075,8 +1094,7 @@ int main(int argc, const char* argv[]) {
10751094
params.motion_bucket_id,
10761095
params.fps,
10771096
params.augmentation_level,
1078-
params.min_cfg,
1079-
params.cfg_scale,
1097+
guidance_params,
10801098
params.sample_method,
10811099
params.sample_steps,
10821100
params.strength,
@@ -1109,8 +1127,7 @@ int main(int argc, const char* argv[]) {
11091127
params.prompt.c_str(),
11101128
params.negative_prompt.c_str(),
11111129
params.clip_skip,
1112-
params.cfg_scale,
1113-
params.guidance,
1130+
guidance_params,
11141131
params.eta,
11151132
params.width,
11161133
params.height,
@@ -1123,12 +1140,7 @@ int main(int argc, const char* argv[]) {
11231140
params.control_strength,
11241141
params.style_ratio,
11251142
params.normalize_input,
1126-
params.input_id_images_path.c_str(),
1127-
params.skip_layers.data(),
1128-
params.skip_layers.size(),
1129-
params.slg_scale,
1130-
params.skip_layer_start,
1131-
params.skip_layer_end);
1143+
params.input_id_images_path.c_str());
11321144
}
11331145
} else { // EDIT
11341146
results = edit(sd_ctx,
@@ -1137,25 +1149,19 @@ int main(int argc, const char* argv[]) {
11371149
params.prompt.c_str(),
11381150
params.negative_prompt.c_str(),
11391151
params.clip_skip,
1140-
params.cfg_scale,
1141-
params.guidance,
1152+
guidance_params,
11421153
params.eta,
11431154
params.width,
11441155
params.height,
11451156
params.sample_method,
11461157
params.sample_steps,
1147-
params.strength,
11481158
params.seed,
11491159
params.batch_count,
11501160
control_image,
11511161
params.control_strength,
11521162
params.style_ratio,
11531163
params.normalize_input,
1154-
params.skip_layers.data(),
1155-
params.skip_layers.size(),
1156-
params.slg_scale,
1157-
params.skip_layer_start,
1158-
params.skip_layer_end);
1164+
params.input_id_images_path.c_str());
11591165
}
11601166

11611167
if (results == NULL) {

model.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,10 +1673,14 @@ SDVersion ModelLoader::get_sd_version() {
16731673
}
16741674
}
16751675
bool is_inpaint = input_block_weight.ne[2] == 9;
1676+
bool is_ip2p = input_block_weight.ne[2] == 8;
16761677
if (is_xl) {
16771678
if (is_inpaint) {
16781679
return VERSION_SDXL_INPAINT;
16791680
}
1681+
if (is_ip2p) {
1682+
return VERSION_SDXL_PIX2PIX;
1683+
}
16801684
return VERSION_SDXL;
16811685
}
16821686

@@ -1692,6 +1696,9 @@ SDVersion ModelLoader::get_sd_version() {
16921696
if (is_inpaint) {
16931697
return VERSION_SD1_INPAINT;
16941698
}
1699+
if (is_ip2p) {
1700+
return VERSION_SD1_PIX2PIX;
1701+
}
16951702
return VERSION_SD1;
16961703
} else if (token_embedding_weight.ne[0] == 1024) {
16971704
if (is_inpaint) {

model.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
enum SDVersion {
2222
VERSION_SD1,
2323
VERSION_SD1_INPAINT,
24+
VERSION_SD1_PIX2PIX,
2425
VERSION_SD2,
2526
VERSION_SD2_INPAINT,
2627
VERSION_SDXL,
2728
VERSION_SDXL_INPAINT,
29+
VERSION_SDXL_PIX2PIX,
2830
VERSION_SVD,
2931
VERSION_SD3,
3032
VERSION_FLUX,
@@ -47,7 +49,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
4749
}
4850

4951
static inline bool sd_version_is_sd1(SDVersion version) {
50-
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) {
52+
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
5153
return true;
5254
}
5355
return false;
@@ -61,7 +63,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
6163
}
6264

6365
static inline bool sd_version_is_sdxl(SDVersion version) {
64-
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) {
66+
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) {
6567
return true;
6668
}
6769
return false;
@@ -81,6 +83,14 @@ static inline bool sd_version_is_dit(SDVersion version) {
8183
return false;
8284
}
8385

86+
// Returns true when `version` is one of the pix2pix model variants
// (VERSION_SD1_PIX2PIX or VERSION_SDXL_PIX2PIX), false for every other version.
static inline bool sd_version_is_edit(SDVersion version) {
    switch (version) {
        case VERSION_SD1_PIX2PIX:
        case VERSION_SDXL_PIX2PIX:
            return true;
        default:
            return false;
    }
}
89+
90+
// Returns true when `version` is either an edit (pix2pix) variant or an
// inpaint variant, as reported by sd_version_is_edit() and
// sd_version_is_inpaint().
// NOTE(review): the name suggests these are the models whose input takes
// extra channels concatenated to the latent — confirm against the callers.
//
// Declared `static inline` to match the other sd_version_* predicates in this
// header; a plain `static` function defined in a header can trigger
// unused-function warnings in translation units that include it without
// calling it.
static inline bool sd_version_use_concat(SDVersion version) {
    return sd_version_is_edit(version) || sd_version_is_inpaint(version);
}
93+
8494
enum PMVersion {
8595
PM_VERSION_1,
8696
PM_VERSION_2,

0 commit comments

Comments
 (0)