Skip to content

Commit e75661d

Browse files
committed
apg: add SD_LOG_CFG_DELTA_NORM
1 parent 573a091 commit e75661d

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

stable-diffusion.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,7 +1196,18 @@ class StableDiffusionGGML {
11961196

11971197
float* deltas = vec_denoised;
11981198

1199-
// https://arxiv.org/pdf/2410.02416
1199+
// APG: https://arxiv.org/pdf/2410.02416
1200+
1201+
bool log_cfg_norm = false;
1202+
const char* SD_LOG_CFG_DELTA_NORM = getenv("SD_LOG_CFG_DELTA_NORM");
1203+
if (SD_LOG_CFG_DELTA_NORM != nullptr) {
1204+
std::string sd_log_cfg_norm_str = SD_LOG_CFG_DELTA_NORM;
1205+
if (sd_log_cfg_norm_str == "ON" || sd_log_cfg_norm_str == "TRUE") {
1206+
log_cfg_norm = true;
1207+
} else if (sd_log_cfg_norm_str != "OFF" && sd_log_cfg_norm_str != "FALSE") {
1208+
LOG_WARN("SD_LOG_CFG_DELTA_NORM environment variable has unexpected value. Assuming default (\"OFF\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_LOG_CFG_DELTA_NORM);
1209+
}
1210+
}
12001211
float apg_scale_factor = 1.;
12011212
float diff_norm = 0;
12021213
float cond_norm_sq = 0;
@@ -1224,7 +1235,7 @@ class StableDiffusionGGML {
12241235
delta += guidance.apg.momentum * apg_momentum_buffer[i];
12251236
apg_momentum_buffer[i] = delta;
12261237
}
1227-
if (guidance.apg.norm_treshold > 0) {
1238+
if (guidance.apg.norm_treshold > 0 || log_cfg_norm) {
12281239
diff_norm += delta * delta;
12291240
}
12301241
if (guidance.apg.eta != 1.0f) {
@@ -1233,6 +1244,9 @@ class StableDiffusionGGML {
12331244
}
12341245
deltas[i] = delta;
12351246
}
1247+
if(log_cfg_norm){
1248+
LOG_INFO("CFG Delta norm: %.2f", sqrtf(diff_norm));
1249+
}
12361250
if (guidance.apg.norm_treshold > 0) {
12371251
diff_norm = sqrtf(diff_norm);
12381252
if (guidance.apg.norm_treshold_smoothing <= 0) {

0 commit comments

Comments
 (0)