
Commit b1cc40c

feat: add Chroma support (#696)

Authored by: stduhpf
Co-authored-by: Green Sky <[email protected]>
Co-authored-by: leejet <[email protected]>

1 parent 884e23e

File tree: 8 files changed, +566 -115 lines

conditioner.hpp

Lines changed: 207 additions & 4 deletions
@@ -747,7 +747,7 @@ struct SD3CLIPEmbedder : public Conditioner {
 
         clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
         clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -902,6 +902,7 @@ struct SD3CLIPEmbedder : public Conditioner {
 
             t5->compute(n_threads,
                         input_ids,
+                        NULL,
                         &chunk_hidden_states_t5,
                         work_ctx);
             {
@@ -1004,6 +1005,7 @@ struct FluxCLIPEmbedder : public Conditioner {
     T5UniGramTokenizer t5_tokenizer;
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<T5Runner> t5;
+    size_t chunk_len = 256;
 
     FluxCLIPEmbedder(ggml_backend_t backend,
                      std::map<std::string, enum ggml_type>& tensor_types,
@@ -1077,7 +1079,7 @@ struct FluxCLIPEmbedder : public Conditioner {
         }
 
         clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1109,7 +1111,6 @@ struct FluxCLIPEmbedder : public Conditioner {
         struct ggml_tensor* pooled = NULL; // [768,]
         std::vector<float> hidden_states_vec;
 
-        size_t chunk_len = 256;
         size_t chunk_count = t5_tokens.size() / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
@@ -1147,6 +1148,7 @@ struct FluxCLIPEmbedder : public Conditioner {
 
             t5->compute(n_threads,
                         input_ids,
+                        NULL,
                         &chunk_hidden_states,
                         work_ctx);
             {
@@ -1196,7 +1198,208 @@
                                       int height,
                                       int adm_in_channels = -1,
                                       bool force_zero_embeddings = false) {
-        auto tokens_and_weights = tokenize(text, 256, true);
+        auto tokens_and_weights = tokenize(text, chunk_len, true);
+        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
+    }
+
+    std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
+                                                                                  int n_threads,
+                                                                                  const std::string& text,
+                                                                                  int clip_skip,
+                                                                                  int width,
+                                                                                  int height,
+                                                                                  int num_input_imgs,
+                                                                                  int adm_in_channels = -1,
+                                                                                  bool force_zero_embeddings = false) {
+        GGML_ASSERT(0 && "Not implemented yet!");
+    }
+
+    std::string remove_trigger_from_prompt(ggml_context* work_ctx,
+                                           const std::string& prompt) {
+        GGML_ASSERT(0 && "Not implemented yet!");
+    }
+};
+
+struct PixArtCLIPEmbedder : public Conditioner {
+    T5UniGramTokenizer t5_tokenizer;
+    std::shared_ptr<T5Runner> t5;
+    size_t chunk_len = 512;
+    bool use_mask = false;
+    int mask_pad = 1;
+
+    PixArtCLIPEmbedder(ggml_backend_t backend,
+                       std::map<std::string, enum ggml_type>& tensor_types,
+                       int clip_skip = -1,
+                       bool use_mask = false,
+                       int mask_pad = 1) : use_mask(use_mask), mask_pad(mask_pad) {
+        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+    }
+
+    void set_clip_skip(int clip_skip) {
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+    }
+
+    void alloc_params_buffer() {
+        t5->alloc_params_buffer();
+    }
+
+    void free_params_buffer() {
+        t5->free_params_buffer();
+    }
+
+    size_t get_params_buffer_size() {
+        size_t buffer_size = 0;
+
+        buffer_size += t5->get_params_buffer_size();
+
+        return buffer_size;
+    }
+
+    std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
+                                                                                  size_t max_length = 0,
+                                                                                  bool padding = false) {
+        auto parsed_attention = parse_prompt_attention(text);
+
+        {
+            std::stringstream ss;
+            ss << "[";
+            for (const auto& item : parsed_attention) {
+                ss << "['" << item.first << "', " << item.second << "], ";
+            }
+            ss << "]";
+            LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+        }
+
+        auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
+            return false;
+        };
+
+        std::vector<int> t5_tokens;
+        std::vector<float> t5_weights;
+        std::vector<float> t5_mask;
+        for (const auto& item : parsed_attention) {
+            const std::string& curr_text = item.first;
+            float curr_weight = item.second;
+
+            std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+        }
+
+        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
+
+        return {t5_tokens, t5_weights, t5_mask};
+    }
+
+    void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
+        float* mask_data = (float*)mask->data;
+        int num_pad = 0;
+        for (int64_t i = 0; i < max_seq_length; i++) {
+            if (num_pad >= num_extra_padding) {
+                break;
+            }
+            if (std::isinf(mask_data[i])) {
+                mask_data[i] = 0;
+                ++num_pad;
+            }
+        }
+        // LOG_DEBUG("PAD: %d", num_pad);
+    }
+
+    SDCondition get_learned_condition_common(ggml_context* work_ctx,
+                                             int n_threads,
+                                             std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
+                                             int clip_skip,
+                                             bool force_zero_embeddings = false) {
+        auto& t5_tokens = std::get<0>(token_and_weights);
+        auto& t5_weights = std::get<1>(token_and_weights);
+        auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
+
+        int64_t t0 = ggml_time_ms();
+        struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
+        struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
+        struct ggml_tensor* pooled = NULL; // [768,]
+        struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,]
+
+        std::vector<float> hidden_states_vec;
+
+        size_t chunk_count = t5_tokens.size() / chunk_len;
+
+        for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
+            // t5
+            std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
+                                          t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
+            std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
+                                             t5_weights.begin() + (chunk_idx + 1) * chunk_len);
+            std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
+                                          t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
+
+            auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+            auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
+
+            t5->compute(n_threads,
+                        input_ids,
+                        t5_attn_mask_chunk,
+                        &chunk_hidden_states,
+                        work_ctx);
+            {
+                auto tensor = chunk_hidden_states;
+                float original_mean = ggml_tensor_mean(tensor);
+                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                            float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+                            value *= chunk_weights[i1];
+                            ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+                        }
+                    }
+                }
+                float new_mean = ggml_tensor_mean(tensor);
+                ggml_tensor_scale(tensor, (original_mean / new_mean));
+            }
+
+            int64_t t1 = ggml_time_ms();
+            LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
+            if (force_zero_embeddings) {
+                float* vec = (float*)chunk_hidden_states->data;
+                for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
+                    vec[i] = 0;
+                }
+            }
+
+            hidden_states_vec.insert(hidden_states_vec.end(),
+                                     (float*)chunk_hidden_states->data,
+                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+        }
+
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
+            ggml_set_f32(hidden_states, 0.f);
+        }
+
+        modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
+
+        return SDCondition(hidden_states, t5_attn_mask, NULL);
+    }
+
+    SDCondition get_learned_condition(ggml_context* work_ctx,
+                                      int n_threads,
+                                      const std::string& text,
+                                      int clip_skip,
+                                      int width,
+                                      int height,
+                                      int adm_in_channels = -1,
+                                      bool force_zero_embeddings = false) {
+        auto tokens_and_weights = tokenize(text, chunk_len, true);
         return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
     }
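
Note on the change above: an optional T5 attention mask is now threaded through tokenization and the T5 encoder. pad_tokens gains a mask output parameter (the SD3 and Flux embedders pass NULL), T5Runner::compute gains a mask argument, and the new PixArtCLIPEmbedder (the Chroma conditioner) builds a per-chunk mask and then re-opens a few padding positions via modify_mask_to_attend_padding. The standalone sketch below only illustrates the mask layout this implies; it is not code from the commit, build_t5_attn_mask is a hypothetical helper, and it assumes real tokens are marked 0.0f and padded positions -INFINITY, consistent with the std::isinf() check above.

// Standalone sketch (not part of the commit): approximate per-chunk mask layout,
// assuming 0.0f = attend, -INFINITY = masked padding (matches the std::isinf() test
// in modify_mask_to_attend_padding). build_t5_attn_mask is a hypothetical helper.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> build_t5_attn_mask(size_t n_prompt_tokens, size_t chunk_len, int mask_pad) {
    std::vector<float> mask(chunk_len, -INFINITY);  // start fully masked
    for (size_t i = 0; i < n_prompt_tokens && i < chunk_len; i++) {
        mask[i] = 0.0f;  // attend to real prompt tokens
    }
    // Like modify_mask_to_attend_padding: let the model also attend to the first
    // few padded positions after the prompt.
    int reopened = 0;
    for (size_t i = 0; i < chunk_len && reopened < mask_pad; i++) {
        if (std::isinf(mask[i])) {
            mask[i] = 0.0f;
            reopened++;
        }
    }
    return mask;
}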

diffusion_model.hpp

Lines changed: 3 additions & 2 deletions
@@ -137,8 +137,9 @@ struct FluxModel : public DiffusionModel {
     FluxModel(ggml_backend_t backend,
               std::map<std::string, enum ggml_type>& tensor_types,
               SDVersion version = VERSION_FLUX,
-              bool flash_attn = false)
-        : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
+              bool flash_attn = false,
+              bool use_mask = false)
+        : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
     }
 
     void alloc_params_buffer() {
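
For context, a minimal construction sketch: the new use_mask flag is simply forwarded to the underlying Flux runner. This is not code from the commit, and the backend and tensor_types variables are assumed to be set up elsewhere.

// Hedged sketch, not from the commit: backend and tensor_types are assumed to exist.
FluxModel model(backend, tensor_types, VERSION_FLUX,
                /*flash_attn=*/false,
                /*use_mask=*/true);  // new flag, forwarded to the Flux runner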

examples/cli/main.cpp

Lines changed: 24 additions & 1 deletion
@@ -132,6 +132,10 @@ struct SDParams {
     float slg_scale = 0.f;
     float skip_layer_start = 0.01f;
     float skip_layer_end = 0.2f;
+
+    bool chroma_use_dit_mask = true;
+    bool chroma_use_t5_mask = false;
+    int chroma_t5_mask_pad = 1;
 };
 
 void print_params(SDParams params) {
@@ -185,6 +189,9 @@ void print_params(SDParams params) {
     printf(" batch_count: %d\n", params.batch_count);
     printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
     printf(" upscale_repeats: %d\n", params.upscale_repeats);
+    printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
+    printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
+    printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
 }
 
 void print_usage(int argc, const char* argv[]) {
@@ -252,6 +259,9 @@ void print_usage(int argc, const char* argv[]) {
     printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
     printf(" --canny apply canny preprocessor (edge detection)\n");
     printf(" --color colors the logging tags according to level\n");
+    printf(" --chroma-disable-dit-mask disable dit mask for chroma\n");
+    printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n");
+    printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
     printf(" -v, --verbose print extra info\n");
 }
 
@@ -643,6 +653,16 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.ref_image_paths.push_back(argv[i]);
+        } else if (arg == "chroma-disable-dit-mask") {
+            params.chroma_use_dit_mask = false;
+        } else if (arg == "--chroma-use-t5-mask") {
+            params.chroma_use_t5_mask = true;
+        } else if (arg == "--chroma-t5-mask-pad") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.chroma_t5_mask_pad = std::stoi(argv[i]);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv);
@@ -952,7 +972,10 @@ int main(int argc, const char* argv[]) {
                                     params.clip_on_cpu,
                                     params.control_net_cpu,
                                     params.vae_on_cpu,
-                                    params.diffusion_flash_attn);
+                                    params.diffusion_flash_attn,
+                                    params.chroma_use_dit_mask,
+                                    params.chroma_use_t5_mask,
+                                    params.chroma_t5_mask_pad);
 
     if (sd_ctx == NULL) {
         printf("new_sd_ctx_t failed\n");
