@@ -1004,6 +1004,7 @@ struct FluxCLIPEmbedder : public Conditioner {
1004
1004
T5UniGramTokenizer t5_tokenizer;
1005
1005
std::shared_ptr<CLIPTextModelRunner> clip_l;
1006
1006
std::shared_ptr<T5Runner> t5;
1007
+ size_t chunk_len = 256 ;
1007
1008
1008
1009
FluxCLIPEmbedder (ggml_backend_t backend,
1009
1010
std::map<std::string, enum ggml_type>& tensor_types,
@@ -1109,7 +1110,6 @@ struct FluxCLIPEmbedder : public Conditioner {
1109
1110
struct ggml_tensor * pooled = NULL ; // [768,]
1110
1111
std::vector<float > hidden_states_vec;
1111
1112
1112
- size_t chunk_len = 256 ;
1113
1113
size_t chunk_count = t5_tokens.size () / chunk_len;
1114
1114
for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
1115
1115
// clip_l
@@ -1196,7 +1196,7 @@ struct FluxCLIPEmbedder : public Conditioner {
1196
1196
int height,
1197
1197
int adm_in_channels = -1 ,
1198
1198
bool force_zero_embeddings = false ) {
1199
- auto tokens_and_weights = tokenize (text, 256 , true );
1199
+ auto tokens_and_weights = tokenize (text, chunk_len , true );
1200
1200
return get_learned_condition_common (work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
1201
1201
}
1202
1202
@@ -1221,6 +1221,7 @@ struct FluxCLIPEmbedder : public Conditioner {
1221
1221
struct PixArtCLIPEmbedder : public Conditioner {
1222
1222
T5UniGramTokenizer t5_tokenizer;
1223
1223
std::shared_ptr<T5Runner> t5;
1224
+ size_t chunk_len = 512 ;
1224
1225
1225
1226
PixArtCLIPEmbedder (ggml_backend_t backend,
1226
1227
std::map<std::string, enum ggml_type>& tensor_types,
@@ -1304,8 +1305,18 @@ struct PixArtCLIPEmbedder : public Conditioner {
1304
1305
1305
1306
std::vector<float > hidden_states_vec;
1306
1307
1307
- size_t chunk_len = 256 ;
1308
1308
size_t chunk_count = t5_tokens.size () / chunk_len;
1309
+
1310
+ bool use_mask = true ;
1311
+ const char * SD_CHROMA_USE_T5_MASK = getenv (" SD_CHROMA_USE_T5_MASK" );
1312
+ if (SD_CHROMA_USE_T5_MASK != nullptr ) {
1313
+ std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1314
+ if (sd_chroma_use_t5_mask_str == " OFF" || sd_chroma_use_t5_mask_str == " FALSE" ) {
1315
+ use_mask = false ;
1316
+ } else if (sd_chroma_use_t5_mask_str != " ON" && sd_chroma_use_t5_mask_str != " TRUE" ) {
1317
+ LOG_WARN (" SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\" ON\" ). (Expected \" ON\" /\" TRUE\" or\" OFF\" /\" FALSE\" , got \" %s\" )" , SD_CHROMA_USE_T5_MASK);
1318
+ }
1319
+ }
1309
1320
for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
1310
1321
// t5
1311
1322
std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
@@ -1316,17 +1327,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
1316
1327
t5_attn_mask_vec.begin () + (chunk_idx + 1 ) * chunk_len);
1317
1328
1318
1329
auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1319
- auto t5_attn_mask_chunk = vector_to_ggml_tensor (work_ctx, chunk_mask);
1320
-
1321
- const char * SD_CHROMA_USE_T5_MASK = getenv (" SD_CHROMA_USE_T5_MASK" );
1322
- if (SD_CHROMA_USE_T5_MASK != nullptr ) {
1323
- std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1324
- if (sd_chroma_use_t5_mask_str == " OFF" || sd_chroma_use_t5_mask_str == " FALSE" ) {
1325
- t5_attn_mask_chunk = NULL ;
1326
- } else if (sd_chroma_use_t5_mask_str != " ON" && sd_chroma_use_t5_mask_str != " TRUE" ) {
1327
- LOG_WARN (" SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\" ON\" ). (Expected \" ON\" /\" TRUE\" or\" OFF\" /\" FALSE\" , got \" %s\" )" , SD_CHROMA_USE_T5_MASK);
1328
- }
1329
- }
1330
+ auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor (work_ctx, chunk_mask) : NULL ;
1330
1331
1331
1332
t5->compute (n_threads,
1332
1333
input_ids,
@@ -1384,7 +1385,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
1384
1385
int height,
1385
1386
int adm_in_channels = -1 ,
1386
1387
bool force_zero_embeddings = false ) {
1387
- auto tokens_and_weights = tokenize (text, 512 , true );
1388
+ auto tokens_and_weights = tokenize (text, chunk_len , true );
1388
1389
return get_learned_condition_common (work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
1389
1390
}
1390
1391
0 commit comments