Skip to content

Commit 4fdedd5

Browse files
committed
Chroma: Fix t5 chunk length
1 parent 42e217d commit 4fdedd5

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

conditioner.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,7 @@ struct FluxCLIPEmbedder : public Conditioner {
10041004
T5UniGramTokenizer t5_tokenizer;
10051005
std::shared_ptr<CLIPTextModelRunner> clip_l;
10061006
std::shared_ptr<T5Runner> t5;
1007+
size_t chunk_len = 256;
10071008

10081009
FluxCLIPEmbedder(ggml_backend_t backend,
10091010
std::map<std::string, enum ggml_type>& tensor_types,
@@ -1109,7 +1110,6 @@ struct FluxCLIPEmbedder : public Conditioner {
11091110
struct ggml_tensor* pooled = NULL; // [768,]
11101111
std::vector<float> hidden_states_vec;
11111112

1112-
size_t chunk_len = 256;
11131113
size_t chunk_count = t5_tokens.size() / chunk_len;
11141114
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
11151115
// clip_l
@@ -1196,7 +1196,7 @@ struct FluxCLIPEmbedder : public Conditioner {
11961196
int height,
11971197
int adm_in_channels = -1,
11981198
bool force_zero_embeddings = false) {
1199-
auto tokens_and_weights = tokenize(text, 256, true);
1199+
auto tokens_and_weights = tokenize(text, chunk_len, true);
12001200
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
12011201
}
12021202

@@ -1221,6 +1221,7 @@ struct FluxCLIPEmbedder : public Conditioner {
12211221
struct PixArtCLIPEmbedder : public Conditioner {
12221222
T5UniGramTokenizer t5_tokenizer;
12231223
std::shared_ptr<T5Runner> t5;
1224+
size_t chunk_len = 512;
12241225

12251226
PixArtCLIPEmbedder(ggml_backend_t backend,
12261227
std::map<std::string, enum ggml_type>& tensor_types,
@@ -1304,8 +1305,18 @@ struct PixArtCLIPEmbedder : public Conditioner {
13041305

13051306
std::vector<float> hidden_states_vec;
13061307

1307-
size_t chunk_len = 256;
13081308
size_t chunk_count = t5_tokens.size() / chunk_len;
1309+
1310+
bool use_mask = true;
1311+
const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
1312+
if (SD_CHROMA_USE_T5_MASK != nullptr) {
1313+
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1314+
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
1315+
use_mask = false;
1316+
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
1317+
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or \"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
1318+
}
1319+
}
13091320
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
13101321
// t5
13111322
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
@@ -1316,17 +1327,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
13161327
t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
13171328

13181329
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1319-
auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask);
1320-
1321-
const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
1322-
if (SD_CHROMA_USE_T5_MASK != nullptr) {
1323-
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1324-
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
1325-
t5_attn_mask_chunk = NULL;
1326-
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
1327-
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
1328-
}
1329-
}
1330+
auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
13301331

13311332
t5->compute(n_threads,
13321333
input_ids,
@@ -1384,7 +1385,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
13841385
int height,
13851386
int adm_in_channels = -1,
13861387
bool force_zero_embeddings = false) {
1387-
auto tokens_and_weights = tokenize(text, 512, true);
1388+
auto tokens_and_weights = tokenize(text, chunk_len, true);
13881389
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
13891390
}
13901391

0 commit comments

Comments
 (0)