@@ -930,14 +930,13 @@ namespace Flux {
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx,
-                                struct ggml_tensor* x,
+                                std::vector<struct ggml_tensor*> imgs,
                                 struct ggml_tensor* timestep,
                                 struct ggml_tensor* context,
                                 struct ggml_tensor* c_concat,
                                 struct ggml_tensor* y,
                                 struct ggml_tensor* guidance,
                                 struct ggml_tensor* pe,
-                                bool kontext_concat = false,
                                 struct ggml_tensor* arange = NULL,
                                 std::vector<int> skip_layers = std::vector<int>(),
                                 SDVersion version = VERSION_FLUX) {
@@ -951,19 +950,31 @@ namespace Flux {
         // pe: (L, d_head/2, 2, 2)
         // return: (N, C, H, W)
 
+        auto x = imgs[0];
         GGML_ASSERT(x->ne[3] == 1);
 
         int64_t W = x->ne[0];
         int64_t H = x->ne[1];
         int64_t C = x->ne[2];
         int64_t patch_size = 2;
-        int pad_h = (patch_size - H % patch_size) % patch_size;
-        int pad_w = (patch_size - W % patch_size) % patch_size;
-        x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
+        int pad_h = (patch_size - x->ne[1] % patch_size) % patch_size;
+        int pad_w = (patch_size - x->ne[0] % patch_size) % patch_size;
 
         // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-        auto img = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]
-        int64_t patchified_img_size = img->ne[1];
+        ggml_tensor* img = NULL;  // [N, h*w, C * patch_size * patch_size]
+        int64_t patchified_img_size;
+        for (auto& x : imgs) {
+            int pad_h = (patch_size - x->ne[1] % patch_size) % patch_size;
+            int pad_w = (patch_size - x->ne[0] % patch_size) % patch_size;
+            ggml_tensor* pad_x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);
+            pad_x = patchify(ctx, pad_x, patch_size);
+            if (img) {
+                img = ggml_concat(ctx, img, pad_x, 1);
+            } else {
+                img = pad_x;
+                patchified_img_size = img->ne[1];
+            }
+        }
 
         if (version == VERSION_FLUX_FILL) {
             GGML_ASSERT(c_concat != NULL);
             ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
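For orientation, the rearrange comment above reduces to simple shape arithmetic: each latent of width W and height H with C channels is padded up to a multiple of patch_size, split into (H/ps)*(W/ps) tokens of C*ps*ps values, and ggml_concat along dim 1 appends the token sequences of all entries in imgs. A minimal standalone sketch of that arithmetic, where the 128x96 and 64x64 sizes are invented for illustration and do not come from this PR:

    #include <cstdint>
    #include <cstdio>

    // Token count for one image after padding and the
    // "b c (h ph) (w pw) -> b (h w) (c ph pw)" rearrange.
    static int64_t patchified_tokens(int64_t W, int64_t H, int64_t patch_size) {
        int64_t pad_w = (patch_size - W % patch_size) % patch_size;
        int64_t pad_h = (patch_size - H % patch_size) % patch_size;
        return ((W + pad_w) / patch_size) * ((H + pad_h) / patch_size);
    }

    int main() {
        // Hypothetical sizes: a 128x96 main latent plus one 64x64 kontext latent, patch_size = 2.
        int64_t main_tokens    = patchified_tokens(128, 96, 2);  // 64 * 48 = 3072
        int64_t kontext_tokens = patchified_tokens(64, 64, 2);   // 32 * 32 = 1024
        // The loop above concatenates the per-image token sequences, so forward_orig
        // ends up processing main_tokens + kontext_tokens tokens.
        printf("total tokens: %lld\n", (long long)(main_tokens + kontext_tokens));
        return 0;
    }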
@@ -999,10 +1010,6 @@ namespace Flux {
             control = patchify(ctx, control, patch_size);
 
             img = ggml_concat(ctx, img, control, 0);
-        } else if (kontext_concat && c_concat != NULL) {
-            ggml_tensor* kontext = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
-            kontext = patchify(ctx, kontext, patch_size);
-            img = ggml_concat(ctx, img, kontext, 1);
         }
 
         auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers);  // [N, h*w, C * patch_size * patch_size]
@@ -1097,8 +1104,8 @@ namespace Flux {
                                 struct ggml_tensor* c_concat,
                                 struct ggml_tensor* y,
                                 struct ggml_tensor* guidance,
-                                bool kontext_concat = false,
-                                std::vector<int> skip_layers = std::vector<int>()) {
+                                std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                                std::vector<int> skip_layers = std::vector<int>()) {
         GGML_ASSERT(x->ne[3] == 1);
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
@@ -1109,6 +1116,9 @@ namespace Flux {
         if (c_concat != NULL) {
             c_concat = to_backend(c_concat);
         }
+        for (auto& img : kontext_imgs) {
+            img = to_backend(img);
+        }
         if (flux_params.is_chroma) {
             const char* SD_CHROMA_ENABLE_GUIDANCE = getenv("SD_CHROMA_ENABLE_GUIDANCE");
             bool disable_guidance = true;
@@ -1148,11 +1158,8 @@ namespace Flux {
         if (flux_params.guidance_embed || flux_params.is_chroma) {
             guidance = to_backend(guidance);
         }
-
-        std::vector<struct ggml_tensor*> imgs{x};
-        if (kontext_concat && c_concat != NULL) {
-            imgs.push_back(c_concat);
-        }
+        auto imgs = kontext_imgs;
+        imgs.insert(imgs.begin(), x);
 
         pe_vec = flux.gen_pe(imgs, context, 2, flux_params.theta, flux_params.axes_dim);
         int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
@@ -1175,14 +1182,13 @@ namespace Flux {
         // }
 
         struct ggml_tensor* out = flux.forward(compute_ctx,
-                                               x,
+                                               imgs,
                                                timesteps,
                                                context,
                                                c_concat,
                                                y,
                                                guidance,
                                                pe,
-                                               kontext_concat,
                                                precompute_arange,
                                                skip_layers,
                                                version);
@@ -1199,17 +1205,17 @@ namespace Flux {
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
                  struct ggml_tensor* guidance,
-                 bool kontext_concat = false,
-                 struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL,
-                 std::vector<int> skip_layers = std::vector<int>()) {
+                 std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
+                 struct ggml_tensor** output = NULL,
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
         // x: [N, in_channels, h, w]
         // timesteps: [N, ]
         // context: [N, max_position, hidden_size]
         // y: [N, adm_in_channels] or [1, adm_in_channels]
         // guidance: [N, ]
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_concat, skip_layers);
+            return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_imgs, skip_layers);
         };
 
         return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1249,7 +1255,7 @@ namespace Flux {
         struct ggml_tensor* out = NULL;
 
         int t0 = ggml_time_ms();
-        compute(8, x, timesteps, context, NULL, y, guidance, false, &out, work_ctx);
+        compute(8, x, timesteps, context, NULL, y, guidance, std::vector<struct ggml_tensor*>(), &out, work_ctx);
         int t1 = ggml_time_ms();
 
         LOG_DEBUG("flux test done in %dms", t1 - t0);
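As a rough usage sketch of the new call shape (not part of this diff): a caller that wants the Kontext behaviour now builds a vector of reference latents and passes it through compute() instead of setting a flag. kontext_latent below is a hypothetical tensor; the other names mirror the existing test code above.

    // Sketch only, assuming the FluxRunner test setup shown above.
    std::vector<struct ggml_tensor*> kontext_imgs;
    kontext_imgs.push_back(kontext_latent);  // hypothetical reference-image latent, [N, C, h, w]

    struct ggml_tensor* out = NULL;
    compute(8, x, timesteps, context, NULL, y, guidance, kontext_imgs, &out, work_ctx);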