
Commit d627ee9

Support masks for ip2p models
1 parent 63e6377 commit d627ee9

File tree

1 file changed: +45 -40 lines changed


stable-diffusion.cpp

Lines changed: 45 additions & 40 deletions
@@ -835,7 +835,7 @@ class StableDiffusionGGML {
                         int start_merge_step,
                         SDCondition id_cond,
                         std::vector<ggml_tensor*> ref_latents = {},
-                        ggml_tensor* noise_mask = nullptr) {
+                        ggml_tensor* denoise_mask = nullptr) {
        std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);

        // TODO (Pix2Pix): separate image guidance params (right now it's reusing distilled guidance)
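The rename is purely semantic: this tensor masks the blend of the denoised latent against init_latent after each sampler step (see the next hunk), so denoise_mask describes it better than noise_mask.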
@@ -1055,10 +1055,10 @@ class StableDiffusionGGML {
            pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
            // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
        }
-       if (noise_mask != nullptr) {
+       if (denoise_mask != nullptr) {
            for (int64_t x = 0; x < denoised->ne[0]; x++) {
                for (int64_t y = 0; y < denoised->ne[1]; y++) {
-                   float mask = ggml_tensor_get_f32(noise_mask, x, y);
+                   float mask = ggml_tensor_get_f32(denoise_mask, x, y);
                    for (int64_t k = 0; k < denoised->ne[2]; k++) {
                        float init = ggml_tensor_get_f32(init_latent, x, y, k);
                        float den  = ggml_tensor_get_f32(denoised, x, y, k);
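The hunk is cut off just before the write-back, but the fetched values make the intent clear: where the mask is 0 the initial latent is kept, and where it is 1 the denoised result wins. A minimal sketch of that blend, assuming a plain per-element lerp (the actual line sits outside this hunk, so treat it as an assumption):

    // Hedged sketch: lerp between the initial and denoised latents.
    // init, den, mask, x, y, k come from the loops above; the write-back
    // itself is not visible in this hunk.
    ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k);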
@@ -1319,7 +1319,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                           bool normalize_input,
                           std::string input_id_images_path,
                           std::vector<ggml_tensor*> ref_latents,
-                          ggml_tensor* masked_latent = NULL) {
+                          ggml_tensor* concat_latent = NULL,
+                          ggml_tensor* denoise_mask = NULL) {
    if (seed < 0) {
        // Generally, when using the provided command line, the seed is always >0.
        // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
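In short, the old masked_latent parameter did double duty. This hunk splits it in two: concat_latent is the extra conditioning concatenated to the model's input channels (inpaint and edit models), while denoise_mask is a latent-resolution mask used after each step to blend the result with init_latent, as in the hunk above.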
@@ -1506,7 +1507,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
    int W = width / 8;
    int H = height / 8;
    LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
-   ggml_tensor* noise_mask = nullptr;
    if (sd_version_is_inpaint(sd_ctx->sd->version)) {
        int64_t mask_channels = 1;
        if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1532,21 +1532,22 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                }
            }
        }
-       if (masked_latent == NULL) {
-           masked_latent = empty_latent;
+       if (concat_latent == NULL) {
+           concat_latent = empty_latent;
        }
-       cond.c_concat = masked_latent;
+       cond.c_concat = concat_latent;
        uncond.c_concat = empty_latent;
-       // noise_mask = masked_latent;
+       denoise_mask = NULL;
    } else if (sd_version_is_edit(sd_ctx->sd->version)) {
-       cond.c_concat = masked_latent;
-       auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], masked_latent->ne[2], masked_latent->ne[3]);
+       auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], init_latent->ne[3]);
        ggml_set_f32(empty_latent, 0);
        uncond.c_concat = empty_latent;
-   } else {
-       noise_mask = masked_latent;
-   }
+       if (concat_latent == NULL) {
+           concat_latent = empty_latent;
+       }
+       cond.c_concat = concat_latent;

+   }
    for (int b = 0; b < batch_count; b++) {
        int64_t sampling_start = ggml_time_ms();
        int64_t cur_seed = seed + b;
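Two behavioral changes are buried here. For inpaint models, denoise_mask is now explicitly cleared, since the mask already reaches the model through the concat channels. For edit (ip2p) models, the unconditional branch gets a zero latent shaped like init_latent rather than like the old masked_latent, and concat_latent falls back to that same zero tensor when the caller passes none. The old final else, which hijacked masked_latent as the noise mask for every other model family, disappears because the mask now arrives as its own parameter.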
@@ -1582,7 +1583,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                                     start_merge_step,
                                     id_cond,
                                     ref_latents,
-                                    noise_mask);
+                                    denoise_mask);

    // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
    // print_ggml_tensor(x_0);
@@ -1802,7 +1803,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,

    sd_image_to_tensor(init_image.data, init_img);

-   ggml_tensor* masked_latent;
+   ggml_tensor* concat_latent;
+   ggml_tensor* denoise_mask = NULL;

    ggml_tensor* init_latent = NULL;
    ggml_tensor* init_moments = NULL;
@@ -1822,63 +1824,65 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
        // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
        sd_image_to_tensor(init_image.data, init_img);
        sd_apply_mask(init_img, mask_img, masked_img);
-       ggml_tensor* masked_latent_0 = NULL;
+       ggml_tensor* masked_latent = NULL;
        if (!sd_ctx->sd->use_tiny_autoencoder) {
            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-           masked_latent_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+           masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
        } else {
-           masked_latent_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+           masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
        }
-       masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent_0->ne[0], masked_latent_0->ne[1], mask_channels + masked_latent_0->ne[2], 1);
-       for (int ix = 0; ix < masked_latent_0->ne[0]; ix++) {
-           for (int iy = 0; iy < masked_latent_0->ne[1]; iy++) {
+       concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], mask_channels + masked_latent->ne[2], 1);
+       for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
+           for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
                int mx = ix * 8;
                int my = iy * 8;
                if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
-                   for (int k = 0; k < masked_latent_0->ne[2]; k++) {
-                       float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
-                       ggml_tensor_set_f32(masked_latent, v, ix, iy, k);
+                   for (int k = 0; k < masked_latent->ne[2]; k++) {
+                       float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                       ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
                    }
                    // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
                    for (int x = 0; x < 8; x++) {
                        for (int y = 0; y < 8; y++) {
                            float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
                            // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
                            // python code was using "b (h 8) (w 8) -> b (8 8) h w"
-                           ggml_tensor_set_f32(masked_latent, m, ix, iy, masked_latent_0->ne[2] + x * 8 + y);
+                           ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                        }
                    }
                } else {
                    float m = ggml_tensor_get_f32(mask_img, mx, my);
-                   ggml_tensor_set_f32(masked_latent, m, ix, iy, 0);
-                   for (int k = 0; k < masked_latent_0->ne[2]; k++) {
-                       float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
-                       ggml_tensor_set_f32(masked_latent, v, ix, iy, k + mask_channels);
+                   ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
+                   for (int k = 0; k < masked_latent->ne[2]; k++) {
+                       float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                       ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
                    }
                }
            }
        }
    } else if (sd_version_is_edit(sd_ctx->sd->version)) {
-       // Not actually masked, we're just highjacking the masked_latent variable since it will be used the same way
+       // Not actually masked, we're just highjacking the concat_latent variable since it will be used the same way
        if (!sd_ctx->sd->use_tiny_autoencoder) {
            if (sd_ctx->sd->is_using_edm_v_parameterization) {
                // for CosXL edit
-               masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
+               concat_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
            } else {
-               masked_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);
+               concat_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);
            }
        } else {
-           masked_latent = init_latent;
+           concat_latent = init_latent;
        }
-   } else {
+   }
+
+   {
        // LOG_WARN("Inpainting with a base model is not great");
-       masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
-       for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
-           for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
+       denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
+       for (int ix = 0; ix < denoise_mask->ne[0]; ix++) {
+           for (int iy = 0; iy < denoise_mask->ne[1]; iy++) {
                int mx = ix * 8;
                int my = iy * 8;
                float m = ggml_tensor_get_f32(mask_img, mx, my);
-               ggml_tensor_set_f32(masked_latent, m, ix, iy);
+               ggml_tensor_set_f32(denoise_mask, m, ix, iy);
            }
        }
    }
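Note the restructuring at the end: building the latent-resolution denoise_mask is no longer an else reserved for base models; it now runs for every version, which is what lets ip2p edits respect a mask. As for the flattening TODO above, the einops pattern "b (h 8) (w 8) -> b (8 8) h w" decomposes height into (h, inner_h) and width into (w, inner_w) and packs the output channel as inner_h * 8 + inner_w; whether that matches x * 8 + y here depends on which ggml axis holds height (ne[0] is conventionally the fastest-varying axis). A self-contained sketch of the reference semantics (plain arrays, no ggml; names are illustrative, not from the commit):

    #include <cstdio>

    int main() {
        // Reference semantics of "b (h 8) (w 8) -> b (8 8) h w":
        // each 8x8 pixel block becomes 64 channels per latent pixel,
        // packed as channel = inner_h * 8 + inner_w.
        float block[8][8];  // one 8x8 mask chunk, indexed [inner_h][inner_w]
        for (int ih = 0; ih < 8; ih++)
            for (int iw = 0; iw < 8; iw++)
                block[ih][iw] = 10.0f * ih + iw;  // encode the 2D position in the value

        float channels[64];  // flattened mask channels for one latent pixel
        for (int ih = 0; ih < 8; ih++)
            for (int iw = 0; iw < 8; iw++)
                channels[ih * 8 + iw] = block[ih][iw];

        // channel 9 decomposes as inner_h = 1, inner_w = 1 -> value 11;
        // the transposed packing (iw * 8 + ih) would place 11 elsewhere.
        printf("channel 9 = %.0f (expect 11)\n", channels[9]);
        return 0;
    }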
@@ -1915,7 +1919,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                                        normalize_input,
                                        input_id_images_path_c_str,
                                        {},
-                                       masked_latent);
+                                       concat_latent,
+                                       denoise_mask);

    size_t t2 = ggml_time_ms();