@@ -747,7 +747,7 @@ struct SD3CLIPEmbedder : public Conditioner {
 
         clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
         clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
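Note: the extra argument threaded through `pad_tokens` here (NULL for SD3, and again for Flux in a later hunk, versus `&t5_mask` in the new PixArtCLIPEmbedder below) indicates the tokenizer can now emit a per-position attention mask while padding. The tokenizer change itself is outside this diff; the standalone sketch below shows one plausible shape for it, assuming the additive-bias convention implied by `modify_mask_to_attend_padding` further down (0 = attend, -INFINITY = masked). The function name and pad-token id are illustrative only, not taken from the repository.

    // Hypothetical sketch, not the actual T5UniGramTokenizer implementation.
    // Pads tokens/weights to max_length and, when attention_mask != NULL,
    // records 0.0f for real tokens and -INFINITY for padding positions.
    #include <cmath>
    #include <vector>

    static void pad_tokens_sketch(std::vector<int>& tokens,
                                  std::vector<float>& weights,
                                  std::vector<float>* attention_mask,  // may be NULL
                                  size_t max_length,
                                  bool padding,
                                  int pad_token_id = 0 /* assumed T5 <pad> id */) {
        if (max_length == 0 || !padding) {
            return;
        }
        size_t orig_size = tokens.size();
        tokens.resize(max_length, pad_token_id);
        weights.resize(max_length, 1.0f);
        if (attention_mask != NULL) {
            attention_mask->assign(max_length, -INFINITY);  // everything masked...
            for (size_t i = 0; i < orig_size && i < max_length; i++) {
                (*attention_mask)[i] = 0.0f;                // ...except real tokens
            }
        }
    }

Passing NULL, as SD3 and Flux do here, keeps their behavior identical to before the change; only callers that want the mask pay for it.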
@@ -902,6 +902,7 @@ struct SD3CLIPEmbedder : public Conditioner {
 
             t5->compute(n_threads,
                         input_ids,
+                        NULL,
                         &chunk_hidden_states_t5,
                         work_ctx);
             {
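The NULL inserted into this call (and into the matching Flux call later in the diff) implies `T5Runner::compute` gained an optional attention-mask tensor between `input_ids` and the output pointer. A declaration-only sketch of the assumed new signature, inferred purely from the call sites in this diff; parameter names are guesses:

    // Assumed shape of the widened T5Runner::compute; not shown in this diff.
    void compute(const int n_threads,
                 struct ggml_tensor* input_ids,
                 struct ggml_tensor* attention_mask,  // NULL => attend to every position
                 ggml_tensor** output,
                 ggml_context* output_ctx);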
@@ -1004,6 +1005,7 @@ struct FluxCLIPEmbedder : public Conditioner {
     T5UniGramTokenizer t5_tokenizer;
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<T5Runner> t5;
+    size_t chunk_len = 256;
 
     FluxCLIPEmbedder(ggml_backend_t backend,
                      std::map<std::string, enum ggml_type>& tensor_types,
@@ -1077,7 +1079,7 @@ struct FluxCLIPEmbedder : public Conditioner {
         }
 
         clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1109,7 +1111,6 @@ struct FluxCLIPEmbedder : public Conditioner {
         struct ggml_tensor* pooled = NULL;  // [768,]
         std::vector<float> hidden_states_vec;
 
-        size_t chunk_len = 256;
         size_t chunk_count = t5_tokens.size() / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
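With `chunk_len` hoisted to a member in the previous hunk, the tokenizer padding, this loop, and `get_learned_condition` all derive from a single value instead of a repeated literal 256. The loop's slicing arithmetic, replayed as a self-contained toy program (data is illustrative):

    // Standalone illustration of the chunking arithmetic used in the loop above.
    // With tokens padded to a multiple of chunk_len, each chunk is the half-open
    // window [chunk_idx * chunk_len, (chunk_idx + 1) * chunk_len).
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t chunk_len = 256;        // FluxCLIPEmbedder's default
        std::vector<int> t5_tokens(512, 0);  // pretend pad_tokens produced 2 chunks
        size_t chunk_count = t5_tokens.size() / chunk_len;
        for (size_t chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
            std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
                                          t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
            std::printf("chunk %zu: %zu tokens\n", chunk_idx, chunk_tokens.size());
        }
        return 0;
    }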
@@ -1147,6 +1148,7 @@ struct FluxCLIPEmbedder : public Conditioner {
 
             t5->compute(n_threads,
                         input_ids,
+                        NULL,
                         &chunk_hidden_states,
                         work_ctx);
             {
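The brace at the end of this hunk opens the (unchanged) block that multiplies each token's hidden state by its prompt weight and then rescales the whole tensor so its mean is unchanged; the same block appears verbatim inside the new PixArtCLIPEmbedder below. A standalone replay of that weight-then-renormalize step on a plain buffer, with toy data:

    // Standalone sketch of the weight-then-renormalize step: scale each token's
    // embedding by its prompt weight, then restore the original global mean so
    // the overall activation magnitude is unchanged.
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_token = 3, n_dim = 4;
        std::vector<float> hidden(n_token * n_dim, 1.0f);  // toy embeddings
        std::vector<float> weights = {1.0f, 1.5f, 0.8f};   // per-token prompt weights

        float original_mean = 0.f, new_mean = 0.f;
        for (float v : hidden) original_mean += v;
        original_mean /= hidden.size();

        for (int t = 0; t < n_token; t++)
            for (int d = 0; d < n_dim; d++)
                hidden[t * n_dim + d] *= weights[t];

        for (float v : hidden) new_mean += v;
        new_mean /= hidden.size();

        for (float& v : hidden) v *= original_mean / new_mean;  // mean preserved
        std::printf("mean restored to %f\n", original_mean);
        return 0;
    }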
@@ -1196,7 +1198,208 @@ struct FluxCLIPEmbedder : public Conditioner {
                                       int height,
                                       int adm_in_channels = -1,
                                       bool force_zero_embeddings = false) {
-        auto tokens_and_weights = tokenize(text, 256, true);
+        auto tokens_and_weights = tokenize(text, chunk_len, true);
+        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
+    }
+
+    std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
+                                                                                  int n_threads,
+                                                                                  const std::string& text,
+                                                                                  int clip_skip,
+                                                                                  int width,
+                                                                                  int height,
+                                                                                  int num_input_imgs,
+                                                                                  int adm_in_channels = -1,
+                                                                                  bool force_zero_embeddings = false) {
+        GGML_ASSERT(0 && "Not implemented yet!");
+    }
+
+    std::string remove_trigger_from_prompt(ggml_context* work_ctx,
+                                           const std::string& prompt) {
+        GGML_ASSERT(0 && "Not implemented yet!");
+    }
+};
+
+struct PixArtCLIPEmbedder : public Conditioner {
+    T5UniGramTokenizer t5_tokenizer;
+    std::shared_ptr<T5Runner> t5;
+    size_t chunk_len = 512;
+    bool use_mask    = false;
+    int mask_pad     = 1;
+
+    PixArtCLIPEmbedder(ggml_backend_t backend,
+                       std::map<std::string, enum ggml_type>& tensor_types,
+                       int clip_skip = -1,
+                       bool use_mask = false,
+                       int mask_pad  = 1) : use_mask(use_mask), mask_pad(mask_pad) {
+        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+    }
+
+    void set_clip_skip(int clip_skip) {
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+    }
+
+    void alloc_params_buffer() {
+        t5->alloc_params_buffer();
+    }
+
+    void free_params_buffer() {
+        t5->free_params_buffer();
+    }
+
+    size_t get_params_buffer_size() {
+        size_t buffer_size = 0;
+
+        buffer_size += t5->get_params_buffer_size();
+
+        return buffer_size;
+    }
+
+    std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
+                                                                                  size_t max_length = 0,
+                                                                                  bool padding = false) {
+        auto parsed_attention = parse_prompt_attention(text);
+
+        {
+            std::stringstream ss;
+            ss << "[";
+            for (const auto& item : parsed_attention) {
+                ss << "['" << item.first << "', " << item.second << "], ";
+            }
+            ss << "]";
+            LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+        }
+
+        auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
+            return false;
+        };
+
+        std::vector<int> t5_tokens;
+        std::vector<float> t5_weights;
+        std::vector<float> t5_mask;
+        for (const auto& item : parsed_attention) {
+            const std::string& curr_text = item.first;
+            float curr_weight            = item.second;
+
+            std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+        }
+
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
+
+        return {t5_tokens, t5_weights, t5_mask};
+    }
+
+    void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
+        float* mask_data = (float*)mask->data;
+        int num_pad      = 0;
+        for (int64_t i = 0; i < max_seq_length; i++) {
+            if (num_pad >= num_extra_padding) {
+                break;
+            }
+            if (std::isinf(mask_data[i])) {
+                mask_data[i] = 0;
+                ++num_pad;
+            }
+        }
+        // LOG_DEBUG("PAD: %d", num_pad);
+    }
+
+    SDCondition get_learned_condition_common(ggml_context* work_ctx,
+                                             int n_threads,
+                                             std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
+                                             int clip_skip,
+                                             bool force_zero_embeddings = false) {
+        auto& t5_tokens        = std::get<0>(token_and_weights);
+        auto& t5_weights       = std::get<1>(token_and_weights);
+        auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
+
+        int64_t t0                              = ggml_time_ms();
+        struct ggml_tensor* hidden_states       = NULL;  // [N, n_token, 4096]
+        struct ggml_tensor* chunk_hidden_states = NULL;  // [n_token, 4096]
+        struct ggml_tensor* pooled              = NULL;  // [768,]
+        struct ggml_tensor* t5_attn_mask        = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec);  // [n_token,]
+
+        std::vector<float> hidden_states_vec;
+
+        size_t chunk_count = t5_tokens.size() / chunk_len;
+
+        for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
+            // t5
+            std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
+                                          t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
+            std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
+                                             t5_weights.begin() + (chunk_idx + 1) * chunk_len);
+            std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
+                                          t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
+
+            auto input_ids          = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+            auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
+
+            t5->compute(n_threads,
+                        input_ids,
+                        t5_attn_mask_chunk,
+                        &chunk_hidden_states,
+                        work_ctx);
+            {
+                auto tensor         = chunk_hidden_states;
+                float original_mean = ggml_tensor_mean(tensor);
+                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                            float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+                            value *= chunk_weights[i1];
+                            ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+                        }
+                    }
+                }
+                float new_mean = ggml_tensor_mean(tensor);
+                ggml_tensor_scale(tensor, (original_mean / new_mean));
+            }
+
+            int64_t t1 = ggml_time_ms();
+            LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
+            if (force_zero_embeddings) {
+                float* vec = (float*)chunk_hidden_states->data;
+                for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
+                    vec[i] = 0;
+                }
+            }
+
+            hidden_states_vec.insert(hidden_states_vec.end(),
+                                     (float*)chunk_hidden_states->data,
+                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+        }
+
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
+            ggml_set_f32(hidden_states, 0.f);
+        }
+
+        modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
+
+        return SDCondition(hidden_states, t5_attn_mask, NULL);
+    }
+
+    SDCondition get_learned_condition(ggml_context* work_ctx,
+                                      int n_threads,
+                                      const std::string& text,
+                                      int clip_skip,
+                                      int width,
+                                      int height,
+                                      int adm_in_channels = -1,
+                                      bool force_zero_embeddings = false) {
+        auto tokens_and_weights = tokenize(text, chunk_len, true);
         return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
     }
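One behavior of the new PixArtCLIPEmbedder worth flagging: after encoding, `modify_mask_to_attend_padding` re-opens up to `mask_pad` padding positions (entries that are -INFINITY in the additive mask) by resetting them to 0, presumably so the model may attend to a little padding as in the reference PixArt setup. Note also that `use_mask` defaults to false, so the per-chunk mask only reaches `t5->compute` when explicitly enabled, while the returned SDCondition always carries the full-length mask. The unmasking logic, replayed standalone on a plain array:

    // Standalone replay of modify_mask_to_attend_padding: flip the first
    // num_extra_padding masked (-inf) entries back to 0 so those padding
    // positions become attendable again.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static void modify_mask_to_attend_padding(float* mask_data, int max_seq_length, int num_extra_padding) {
        int num_pad = 0;
        for (int i = 0; i < max_seq_length; i++) {
            if (num_pad >= num_extra_padding) break;
            if (std::isinf(mask_data[i])) {
                mask_data[i] = 0;
                ++num_pad;
            }
        }
    }

    int main() {
        std::vector<float> mask = {0, 0, 0, -INFINITY, -INFINITY, -INFINITY};
        modify_mask_to_attend_padding(mask.data(), (int)mask.size(), /*num_extra_padding=*/1);
        for (float v : mask) std::printf("%s ", std::isinf(v) ? "-inf" : "0");
        std::printf("\n");  // prints: 0 0 0 0 -inf -inf
        return 0;
    }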