Skip to content

Instruct-Pix2pix/CosXL-Edit support #679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 42 additions & 36 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,16 @@ struct SDParams {

std::string prompt;
std::string negative_prompt;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float guidance = 3.5f;
float eta = 0.f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float img_cfg_scale = INFINITY;
float guidance = 3.5f;
float eta = 0.f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;

int video_frames = 6;
int motion_bucket_id = 127;
Expand Down Expand Up @@ -175,6 +176,7 @@ void print_params(SDParams params) {
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" eta: %.2f\n", params.eta);
Expand Down Expand Up @@ -232,7 +234,8 @@ void print_usage(int argc, const char* argv[]) {
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
printf(" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n");
printf(" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n");
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
Expand Down Expand Up @@ -462,6 +465,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.cfg_scale = std::stof(argv[i]);
} else if (arg == "--img-cfg-scale") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.img_cfg_scale = std::stof(argv[i]);
} else if (arg == "--guidance") {
if (++i >= argc) {
invalid_arg = true;
Expand Down Expand Up @@ -743,6 +752,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.output_path = "output.gguf";
}
}

if (!isfinite(params.img_cfg_scale)) {
params.img_cfg_scale = params.cfg_scale;
}
}

static std::string sd_basename(const std::string& path) {
Expand Down Expand Up @@ -837,6 +850,18 @@ int main(int argc, const char* argv[]) {

parse_args(argc, argv, params);

sd_guidance_params_t guidance_params = {params.cfg_scale,
params.img_cfg_scale,
params.min_cfg,
params.guidance,
{
params.skip_layers.data(),
params.skip_layers.size(),
params.skip_layer_start,
params.skip_layer_end,
params.slg_scale,
}};

sd_set_log_callback(sd_log_cb, (void*)&params);

if (params.verbose) {
Expand Down Expand Up @@ -1029,8 +1054,7 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
guidance_params,
params.eta,
params.width,
params.height,
Expand All @@ -1042,12 +1066,7 @@ int main(int argc, const char* argv[]) {
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
params.input_id_images_path.c_str());
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,
Expand All @@ -1063,8 +1082,7 @@ int main(int argc, const char* argv[]) {
params.motion_bucket_id,
params.fps,
params.augmentation_level,
params.min_cfg,
params.cfg_scale,
guidance_params,
params.sample_method,
params.sample_steps,
params.strength,
Expand Down Expand Up @@ -1097,8 +1115,7 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
guidance_params,
params.eta,
params.width,
params.height,
Expand All @@ -1111,12 +1128,7 @@ int main(int argc, const char* argv[]) {
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
params.input_id_images_path.c_str());
}
} else { // EDIT
results = edit(sd_ctx,
Expand All @@ -1125,25 +1137,19 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
guidance_params,
params.eta,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count,
control_image,
params.control_strength,
params.style_ratio,
params.normalize_input,
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
params.input_id_images_path.c_str());
}

if (results == NULL) {
Expand Down
7 changes: 7 additions & 0 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1674,10 +1674,14 @@ SDVersion ModelLoader::get_sd_version() {
}
}
bool is_inpaint = input_block_weight.ne[2] == 9;
bool is_ip2p = input_block_weight.ne[2] == 8;
if (is_xl) {
if (is_inpaint) {
return VERSION_SDXL_INPAINT;
}
if (is_ip2p) {
return VERSION_SDXL_PIX2PIX;
}
return VERSION_SDXL;
}

Expand All @@ -1693,6 +1697,9 @@ SDVersion ModelLoader::get_sd_version() {
if (is_inpaint) {
return VERSION_SD1_INPAINT;
}
if (is_ip2p) {
return VERSION_SD1_PIX2PIX;
}
return VERSION_SD1;
} else if (token_embedding_weight.ne[0] == 1024) {
if (is_inpaint) {
Expand Down
14 changes: 12 additions & 2 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
enum SDVersion {
VERSION_SD1,
VERSION_SD1_INPAINT,
VERSION_SD1_PIX2PIX,
VERSION_SD2,
VERSION_SD2_INPAINT,
VERSION_SDXL,
VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX,
VERSION_SVD,
VERSION_SD3,
VERSION_FLUX,
Expand All @@ -47,7 +49,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
}

static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
return true;
}
return false;
Expand All @@ -61,7 +63,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
}

static inline bool sd_version_is_sdxl(SDVersion version) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) {
return true;
}
return false;
Expand All @@ -81,6 +83,14 @@ static inline bool sd_version_is_dit(SDVersion version) {
return false;
}

// Returns true when `version` is one of the instruct-pix2pix ("edit")
// model variants, i.e. VERSION_SD1_PIX2PIX or VERSION_SDXL_PIX2PIX;
// every other version yields false.
static inline bool sd_version_is_edit(SDVersion version) {
    switch (version) {
        case VERSION_SD1_PIX2PIX:
        case VERSION_SDXL_PIX2PIX:
            return true;
        default:
            return false;
    }
}

// Returns true when `version` is either an edit (pix2pix) variant or an
// inpaint variant, as reported by sd_version_is_edit() and
// sd_version_is_inpaint().
// NOTE(review): presumably these are the models whose UNet input takes extra
// channels concatenated to the latent — confirm against the call sites.
//
// Declared `static inline` like the other sd_version_* predicates in this
// header so that translation units which include the header without calling
// this function do not get an unused-function warning.
static inline bool sd_version_use_concat(SDVersion version) {
    return sd_version_is_edit(version) || sd_version_is_inpaint(version);
}

enum PMVersion {
PM_VERSION_1,
PM_VERSION_2,
Expand Down
Loading
Loading