
Commit b8888fc

Kontext refactor
1 parent 6bceb65 commit b8888fc

6 files changed: +173 / -86 lines


diffusion_model.hpp

Lines changed: 23 additions & 23 deletions
@@ -13,13 +13,13 @@ struct DiffusionModel {
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
                  struct ggml_tensor* guidance,
-                 int num_video_frames = -1,
-                 std::vector<struct ggml_tensor*> controls = {},
-                 float control_strength = 0.f,
-                 bool kontext_concat = false,
-                 struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL,
-                 std::vector<int> skip_layers = std::vector<int>()) = 0;
+                 int num_video_frames = -1,
+                 std::vector<struct ggml_tensor*> controls = {},
+                 float control_strength = 0.f,
+                 std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                 struct ggml_tensor** output = NULL,
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) = 0;
     virtual void alloc_params_buffer() = 0;
     virtual void free_params_buffer() = 0;
     virtual void free_compute_buffer() = 0;
@@ -69,13 +69,13 @@ struct UNetModel : public DiffusionModel {
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
                  struct ggml_tensor* guidance,
-                 int num_video_frames = -1,
-                 std::vector<struct ggml_tensor*> controls = {},
-                 float control_strength = 0.f,
-                 bool kontext_concat = false,
-                 struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL,
-                 std::vector<int> skip_layers = std::vector<int>()) {
+                 int num_video_frames = -1,
+                 std::vector<struct ggml_tensor*> controls = {},
+                 float control_strength = 0.f,
+                 std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                 struct ggml_tensor** output = NULL,
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
         (void)skip_layers; // SLG doesn't work with UNet models
         return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
     }
@@ -123,7 +123,7 @@ struct MMDiTModel : public DiffusionModel {
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
-                 bool kontext_concat = false,
+                 std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
@@ -172,14 +172,14 @@ struct FluxModel : public DiffusionModel {
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
                  struct ggml_tensor* guidance,
-                 int num_video_frames = -1,
-                 std::vector<struct ggml_tensor*> controls = {},
-                 float control_strength = 0.f,
-                 bool kontext_concat = false,
-                 struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL,
-                 std::vector<int> skip_layers = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_concat, output, output_ctx, skip_layers);
+                 int num_video_frames = -1,
+                 std::vector<struct ggml_tensor*> controls = {},
+                 float control_strength = 0.f,
+                 std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                 struct ggml_tensor** output = NULL,
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
+        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_imgs, output, output_ctx, skip_layers);
     }
 };
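
The change above replaces the single bool kontext_concat flag with an explicit list of reference latents, added as a defaulted parameter so call sites that omit it keep compiling. A minimal standalone sketch of that defaulted-parameter pattern; ggml_tensor is stubbed out and MockDiffusionModel is a made-up stand-in, not the real DiffusionModel interface:

// Illustrative only: stubbed types and a trimmed parameter list.
#include <cstdio>
#include <vector>

struct ggml_tensor {};  // stub; the real type comes from ggml

struct MockDiffusionModel {
    // Trailing parameters mirror the new signature: kontext_imgs defaults to an
    // empty vector, just as the removed kontext_concat flag defaulted to false.
    void compute(ggml_tensor* x,
                 std::vector<ggml_tensor*> controls = {},
                 float control_strength = 0.f,
                 std::vector<ggml_tensor*> kontext_imgs = std::vector<ggml_tensor*>()) {
        (void)x;
        (void)control_strength;
        std::printf("controls=%zu kontext_imgs=%zu\n", controls.size(), kontext_imgs.size());
    }
};

int main() {
    ggml_tensor latent, ref1, ref2;
    MockDiffusionModel model;
    model.compute(&latent);                           // pre-refactor style call still compiles
    model.compute(&latent, {}, 0.f, {&ref1, &ref2});  // passes two reference latents
    return 0;
}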

examples/cli/main.cpp

Lines changed: 50 additions & 9 deletions
@@ -97,6 +97,8 @@ struct SDParams {
     std::string mask_path;
     std::string control_image_path;
 
+    std::vector<std::string> kontext_image_paths;
+
     std::string prompt;
     std::string negative_prompt;
     float min_cfg = 1.0f;
@@ -289,6 +291,7 @@ void print_usage(int argc, const char* argv[]) {
     printf(" --preview-path [PATH} path to write preview image to (default: ./preview.png)\n");
     printf(" --color Colors the logging tags according to level\n");
     printf(" -v, --verbose print extra info\n");
+    printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n");
 }
 
 void parse_args(int argc, const char** argv, SDParams& params) {
@@ -724,6 +727,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.imatrix_in.push_back(std::string(argv[i]));
+        } else if (arg == "-ki" || arg == "--kontext-img") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.kontext_image_paths.push_back(argv[i]);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv);
@@ -958,12 +967,10 @@ int main(int argc, const char* argv[]) {
             params.skip_layer_end,
             params.slg_scale,
         },
-        {
-            params.apg_eta,
-            params.apg_momentum,
-            params.apg_norm_threshold,
-            params.apg_norm_smoothing
-        }};
+        {params.apg_eta,
+         params.apg_momentum,
+         params.apg_norm_threshold,
+         params.apg_norm_smoothing}};
 
     sd_set_log_callback(sd_log_cb, (void*)&params);
     sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);
@@ -1007,8 +1014,40 @@ int main(int argc, const char* argv[]) {
         fprintf(stderr, "SVD support is broken, do not use it!!!\n");
         return 1;
     }
-
     bool vae_decode_only = true;
+
+    std::vector<sd_image_t> kontext_imgs;
+    for (auto& path : params.kontext_image_paths) {
+        vae_decode_only = false;
+        int c = 0;
+        int width = 0;
+        int height = 0;
+        uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
+        if (image_buffer == NULL) {
+            fprintf(stderr, "load image from '%s' failed\n", path.c_str());
+            return 1;
+        }
+        if (c < 3) {
+            fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
+            free(image_buffer);
+            return 1;
+        }
+        if (width <= 0) {
+            fprintf(stderr, "error: the width of image must be greater than 0\n");
+            free(image_buffer);
+            return 1;
+        }
+        if (height <= 0) {
+            fprintf(stderr, "error: the height of image must be greater than 0\n");
+            free(image_buffer);
+            return 1;
+        }
+        kontext_imgs.push_back({(uint32_t)width,
+                                (uint32_t)height,
+                                3,
+                                image_buffer});
+    }
+
     uint8_t* input_image_buffer = NULL;
     uint8_t* control_image_buffer = NULL;
     uint8_t* mask_image_buffer = NULL;
@@ -1148,7 +1187,8 @@ int main(int argc, const char* argv[]) {
                               params.control_strength,
                               params.style_ratio,
                               params.normalize_input,
-                              params.input_id_images_path.c_str());
+                              params.input_id_images_path.c_str(),
+                              kontext_imgs.data(), kontext_imgs.size());
     } else {
         sd_image_t input_image = {(uint32_t)params.width,
                                   (uint32_t)params.height,
@@ -1210,7 +1250,8 @@ int main(int argc, const char* argv[]) {
                               params.control_strength,
                               params.style_ratio,
                               params.normalize_input,
-                              params.input_id_images_path.c_str());
+                              params.input_id_images_path.c_str(),
+                              kontext_imgs.data(), kontext_imgs.size());
         }
     }
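
The flag is repeatable: each -ki / --kontext-img occurrence appends one path to kontext_image_paths, every image is then decoded with stbi_load, and the resulting sd_image_t list is handed to the generation call as kontext_imgs.data() plus kontext_imgs.size(). A stripped-down, standalone sketch of just the repeatable-flag handling; the Params struct and main below are illustrative, not the real SDParams or CLI:

// Illustrative only: everything except the -ki handling is omitted.
#include <cstdio>
#include <string>
#include <vector>

struct Params {
    std::vector<std::string> kontext_image_paths;  // mirrors SDParams::kontext_image_paths
};

int main(int argc, const char** argv) {
    Params params;
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "-ki" || arg == "--kontext-img") {
            if (++i >= argc) {  // flag given without a following path
                std::fprintf(stderr, "error: missing argument for %s\n", arg.c_str());
                return 1;
            }
            // The flag may appear multiple times; every occurrence adds one reference image.
            params.kontext_image_paths.push_back(argv[i]);
        }
    }
    for (const auto& path : params.kontext_image_paths) {
        std::printf("kontext image: %s\n", path.c_str());
    }
    return 0;
}

Run as, for example, ./kontext_args -ki ref1.png -ki ref2.png, this prints both paths in the order given, which is the order in which the CLI loads and forwards the reference images.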

flux.hpp

Lines changed: 32 additions & 26 deletions
@@ -930,14 +930,13 @@ namespace Flux {
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx,
-                                    struct ggml_tensor* x,
+                                    std::vector<struct ggml_tensor*> imgs,
                                     struct ggml_tensor* timestep,
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* c_concat,
                                     struct ggml_tensor* y,
                                     struct ggml_tensor* guidance,
                                     struct ggml_tensor* pe,
-                                    bool kontext_concat = false,
                                     struct ggml_tensor* arange = NULL,
                                     std::vector<int> skip_layers = std::vector<int>(),
                                     SDVersion version = VERSION_FLUX) {
@@ -951,19 +950,31 @@ namespace Flux {
             // pe: (L, d_head/2, 2, 2)
             // return: (N, C, H, W)
 
+            auto x = imgs[0];
             GGML_ASSERT(x->ne[3] == 1);
 
             int64_t W = x->ne[0];
             int64_t H = x->ne[1];
             int64_t C = x->ne[2];
             int64_t patch_size = 2;
-            int pad_h = (patch_size - H % patch_size) % patch_size;
-            int pad_w = (patch_size - W % patch_size) % patch_size;
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
+            int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size;
+            int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size;
 
             // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-            auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
-            int64_t patchified_img_size = img->ne[1];
+            ggml_tensor* img = NULL; // [N, h*w, C * patch_size * patch_size]
+            int64_t patchified_img_size;
+            for (auto& x : imgs) {
+                int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size;
+                int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size;
+                ggml_tensor* pad_x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);
+                pad_x = patchify(ctx, pad_x, patch_size);
+                if (img) {
+                    img = ggml_concat(ctx, img, pad_x, 1);
+                } else {
+                    img = pad_x;
+                    patchified_img_size = img->ne[1];
+                }
+            }
             if (version == VERSION_FLUX_FILL) {
                 GGML_ASSERT(c_concat != NULL);
                 ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
@@ -999,10 +1010,6 @@ namespace Flux {
                 control = patchify(ctx, control, patch_size);
 
                 img = ggml_concat(ctx, img, control, 0);
-            } else if (kontext_concat && c_concat != NULL) {
-                ggml_tensor* kontext = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
-                kontext = patchify(ctx, kontext, patch_size);
-                img = ggml_concat(ctx, img, kontext, 1);
             }
 
             auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size]
@@ -1097,8 +1104,8 @@ namespace Flux {
                      struct ggml_tensor* c_concat,
                      struct ggml_tensor* y,
                      struct ggml_tensor* guidance,
-                     bool kontext_concat = false,
-                     std::vector<int> skip_layers = std::vector<int>()) {
+                     std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                     std::vector<int> skip_layers = std::vector<int>()) {
         GGML_ASSERT(x->ne[3] == 1);
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
 
@@ -1109,6 +1116,9 @@ namespace Flux {
         if (c_concat != NULL) {
             c_concat = to_backend(c_concat);
         }
+        for (auto &img : kontext_imgs){
+            img = to_backend(img);
+        }
         if (flux_params.is_chroma) {
             const char* SD_CHROMA_ENABLE_GUIDANCE = getenv("SD_CHROMA_ENABLE_GUIDANCE");
             bool disable_guidance = true;
@@ -1148,11 +1158,8 @@ namespace Flux {
         if (flux_params.guidance_embed || flux_params.is_chroma) {
             guidance = to_backend(guidance);
         }
-
-        std::vector<struct ggml_tensor*> imgs{x};
-        if (kontext_concat && c_concat != NULL) {
-            imgs.push_back(c_concat);
-        }
+        auto imgs = kontext_imgs;
+        imgs.insert(imgs.begin(), x);
 
         pe_vec = flux.gen_pe(imgs, context, 2, flux_params.theta, flux_params.axes_dim);
         int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
@@ -1175,14 +1182,13 @@ namespace Flux {
         // }
 
         struct ggml_tensor* out = flux.forward(compute_ctx,
-                                               x,
+                                               imgs,
                                                timesteps,
                                                context,
                                                c_concat,
                                                y,
                                                guidance,
                                                pe,
-                                               kontext_concat,
                                                precompute_arange,
                                                skip_layers,
                                                version);
@@ -1199,17 +1205,17 @@ namespace Flux {
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
                  struct ggml_tensor* guidance,
-                 bool kontext_concat = false,
-                 struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL,
-                 std::vector<int> skip_layers = std::vector<int>()) {
+                 std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                 struct ggml_tensor** output = NULL,
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
         // x: [N, in_channels, h, w]
         // timesteps: [N, ]
         // context: [N, max_position, hidden_size]
         // y: [N, adm_in_channels] or [1, adm_in_channels]
         // guidance: [N, ]
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_concat, skip_layers);
+            return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_imgs, skip_layers);
         };
 
         return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1249,7 +1255,7 @@ namespace Flux {
         struct ggml_tensor* out = NULL;
 
         int t0 = ggml_time_ms();
-        compute(8, x, timesteps, context, NULL, y, guidance, false, &out, work_ctx);
+        compute(8, x, timesteps, context, NULL, y, guidance, std::vector<struct ggml_tensor*>(), &out, work_ctx);
         int t1 = ggml_time_ms();
 
         LOG_DEBUG("flux test done in %dms", t1 - t0);
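
With this refactor, every entry of imgs (the noisy latent first, then each kontext reference latent) is padded to a multiple of patch_size = 2, patchified into tokens of C * patch_size * patch_size values, and concatenated along the token dimension, and gen_pe is called on the whole list so the reference tokens receive their own positional embeddings. patchified_img_size is recorded only for the first image, presumably so the later unpatchify step keeps just the main latent's tokens in the output. A back-of-the-envelope sketch of the resulting sequence-length arithmetic; the latent shapes in main() are made-up examples, not values from the commit:

// Plain C++ (no ggml): counts the tokens contributed by each image after
// padding to a multiple of patch_size and 2x2 patchification.
#include <cstdint>
#include <cstdio>
#include <vector>

struct LatentShape {
    int64_t W, H, C;
};

int64_t token_count(const LatentShape& s, int64_t patch_size = 2) {
    int64_t pad_w = (patch_size - s.W % patch_size) % patch_size;
    int64_t pad_h = (patch_size - s.H % patch_size) % patch_size;
    return ((s.W + pad_w) / patch_size) * ((s.H + pad_h) / patch_size);
}

int main() {
    // imgs[0] is the noisy latent; the rest are kontext reference latents (hypothetical sizes).
    std::vector<LatentShape> imgs = {{128, 128, 16}, {128, 128, 16}, {96, 64, 16}};
    int64_t total = 0;
    for (size_t i = 0; i < imgs.size(); i++) {
        int64_t n = token_count(imgs[i]);
        total += n;
        std::printf("img %zu: %lld tokens, each %lld values wide\n",
                    i, (long long)n, (long long)(imgs[i].C * 2 * 2));
    }
    std::printf("concatenated sequence length: %lld tokens\n", (long long)total);
    std::printf("patchified_img_size (first image only): %lld tokens\n",
                (long long)token_count(imgs[0]));
    return 0;
}

Since each reference latent of comparable size adds roughly as many tokens as the main latent, extra kontext images increase the cost of the attention layers, not just the conditioning inputs.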

ggml

Submodule ggml updated from 988abe2 to 9e4bee1
