diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index ee4d88f0c..c1421fc82 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -16,6 +16,7 @@ struct DiffusionModel {
                          int num_video_frames = -1,
                          std::vector<struct ggml_tensor*> controls = {},
                          float control_strength = 0.f,
+                         struct ggml_context* persistent_work_ctx = NULL,
                          struct ggml_tensor** output = NULL,
                          struct ggml_context* output_ctx = NULL,
                          std::vector<int> skip_layers = std::vector<int>()) = 0;
@@ -33,8 +34,14 @@ struct UNetModel : public DiffusionModel {
     UNetModel(ggml_backend_t backend,
               std::map<std::string, enum ggml_type>& tensor_types,
               SDVersion version = VERSION_SD1,
-              bool flash_attn = false)
-        : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
+              bool flash_attn = false,
+              // DeepCache parameters
+              int dc_cache_interval = 0,
+              int dc_cache_depth = 3,
+              int dc_start_steps = 0,
+              int dc_end_steps = 9999)
+        : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn,
+               dc_cache_interval, dc_cache_depth, dc_start_steps, dc_end_steps) {
     }

     void alloc_params_buffer() {
@@ -71,13 +78,14 @@ struct UNetModel : public DiffusionModel {
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
+                 struct ggml_context* persistent_work_ctx = NULL,
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
         (void)skip_layers;  // SLG doesn't work with UNet models
-        return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
+        return unet.compute(n_threads, x, timesteps, context, c_concat, y, persistent_work_ctx, num_video_frames, controls, control_strength, output, output_ctx);
     }
 };

 struct MMDiTModel : public DiffusionModel {
     MMDiTRunner mmdit;
@@ -121,9 +129,11 @@ struct MMDiTModel : public DiffusionModel {
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
+                 struct ggml_context* persistent_work_ctx = NULL,
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
+        (void)persistent_work_ctx;  // Not used by MMDiT
         return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
     }
 };
@@ -172,9 +182,11 @@ struct FluxModel : public DiffusionModel {
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
+                 struct ggml_context* persistent_work_ctx = NULL,
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
+        (void)persistent_work_ctx;  // Not used by Flux
         return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
     }
 };
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index af6b2bbdb..091a531f6 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -129,6 +129,12 @@ struct SDParams {
     float slg_scale = 0.f;
     float skip_layer_start = 0.01f;
     float skip_layer_end = 0.2f;
+
+    // DeepCache parameters
+    int dc_cache_interval = 0;  // 0 to disable
+    int dc_cache_depth = 3;
+    int dc_start_steps = 0;
+    int dc_end_steps = 9999;  // effectively all steps
 };

 void print_params(SDParams params) {
@@ -178,6 +184,10 @@ void print_params(SDParams params) {
     printf("    batch_count:       %d\n", params.batch_count);
     printf("    vae_tiling:        %s\n", params.vae_tiling ? "true" : "false");
"true" : "false"); printf(" upscale_repeats: %d\n", params.upscale_repeats); + if (params.dc_cache_interval > 0) { + printf(" deepcache: interval=%d, depth=%d, start=%d, end=%d\n", + params.dc_cache_interval, params.dc_cache_depth, params.dc_start_steps, params.dc_end_steps); + } } void print_usage(int argc, const char* argv[]) { @@ -244,6 +254,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" --color Colors the logging tags according to level\n"); + printf(" --deepcache CACHE_PARAMS Enable DeepCache for UNet. CACHE_PARAMS are comma-separated: interval,depth,start_steps,end_steps. Example: \"3,3,0,1000\"\n"); printf(" -v, --verbose print extra info\n"); } @@ -629,6 +640,46 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.skip_layer_end = std::stof(argv[i]); + } else if (arg == "--deepcache") { + if (++i >= argc) { + invalid_arg = true; + break; + } + std::string dc_params_str = argv[i]; + std::vector dc_tokens; + size_t start = 0; + size_t end = dc_params_str.find(','); + while (end != std::string::npos) { + dc_tokens.push_back(dc_params_str.substr(start, end - start)); + start = end + 1; + end = dc_params_str.find(',', start); + } + dc_tokens.push_back(dc_params_str.substr(start)); + + if (dc_tokens.size() != 4) { + fprintf(stderr, "error: --deepcache requires 4 comma-separated values: interval,depth,start_steps,end_steps\n"); + exit(1); + } + try { + params.dc_cache_interval = std::stoi(dc_tokens[0]); + params.dc_cache_depth = std::stoi(dc_tokens[1]); + params.dc_start_steps = std::stoi(dc_tokens[2]); + params.dc_end_steps = std::stoi(dc_tokens[3]); + if (params.dc_cache_interval <= 0) { + fprintf(stderr, "error: deepcache interval must be > 0\n"); + exit(1); + } + if (params.dc_cache_depth < 0) { + fprintf(stderr, "error: deepcache depth must be >= 0\n"); + exit(1); + } + } catch (const std::invalid_argument& e) { + fprintf(stderr, "error: invalid number in --deepcache parameters: %s\n", e.what()); + exit(1); + } catch (const std::out_of_range& e) { + fprintf(stderr, "error: number out of range in --deepcache parameters: %s\n", e.what()); + exit(1); + } } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -900,7 +951,11 @@ int main(int argc, const char* argv[]) { params.clip_on_cpu, params.control_net_cpu, params.vae_on_cpu, - params.diffusion_flash_attn); + params.diffusion_flash_attn, + params.dc_cache_interval, + params.dc_cache_depth, + params.dc_start_steps, + params.dc_end_steps); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..481b5e1c0 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -112,17 +112,32 @@ class StableDiffusionGGML { std::shared_ptr denoiser = std::make_shared(); + // DeepCache parameters for UNet + int dc_cache_interval_unet_ = 0; + int dc_cache_depth_unet_ = 3; + int dc_start_steps_unet_ = 0; + int dc_end_steps_unet_ = 9999; + StableDiffusionGGML() = default; StableDiffusionGGML(int n_threads, bool vae_decode_only, bool free_params_immediately, std::string lora_model_dir, - rng_type_t rng_type) + rng_type_t rng_type, + // DeepCache parameters + int dc_cache_interval, + int dc_cache_depth, + int dc_start_steps, + int dc_end_steps) : n_threads(n_threads), vae_decode_only(vae_decode_only), free_params_immediately(free_params_immediately), - 
-          lora_model_dir(lora_model_dir) {
+          lora_model_dir(lora_model_dir),
+          dc_cache_interval_unet_(dc_cache_interval),
+          dc_cache_depth_unet_(dc_cache_depth),
+          dc_start_steps_unet_(dc_start_steps),
+          dc_end_steps_unet_(dc_end_steps) {
         if (rng_type == STD_DEFAULT_RNG) {
             rng = std::make_shared<STDDefaultRNG>();
         } else if (rng_type == CUDA_RNG) {
@@ -342,7 +357,14 @@ class StableDiffusionGGML {
             } else {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version);
             }
-            diffusion_model = std::make_shared<UNetModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
+            LOG_DEBUG("DeepCache: StableDiffusionGGML::load_from_file. About to create UNetModel. "
+                      "dc_cache_interval_unet_: %d, dc_cache_depth_unet_: %d, "
+                      "dc_start_steps_unet_: %d, dc_end_steps_unet_: %d",
+                      this->dc_cache_interval_unet_, this->dc_cache_depth_unet_,
+                      this->dc_start_steps_unet_, this->dc_end_steps_unet_);
+            diffusion_model = std::make_shared<UNetModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn,
+                                                          this->dc_cache_interval_unet_, this->dc_cache_depth_unet_,
+                                                          this->dc_start_steps_unet_, this->dc_end_steps_unet_);
         }

         cond_stage_model->alloc_params_buffer();
@@ -617,8 +639,11 @@ class StableDiffusionGGML {
         }

         int64_t t0 = ggml_time_ms();
         struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
-        diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out);
+
+        // compute(n_threads, x, timesteps, context, c_concat, y, guidance, num_video_frames, controls, control_strength, persistent_work_ctx, output, output_ctx)
+        diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, work_ctx, &out, work_ctx);
         diffusion_model->free_compute_buffer();

         double result = 0.f;
@@ -890,6 +915,7 @@ class StableDiffusionGGML {
                                          -1,
                                          controls,
                                          control_strength,
+                                         work_ctx,
                                          &out_cond);
             } else {
                 diffusion_model->compute(n_threads,
@@ -902,6 +928,7 @@ class StableDiffusionGGML {
                                          -1,
                                          controls,
                                          control_strength,
+                                         work_ctx,
                                          &out_cond);
             }

@@ -922,6 +949,7 @@ class StableDiffusionGGML {
                                          -1,
                                          controls,
                                          control_strength,
+                                         work_ctx,
                                          &out_uncond);
                 negative_data = (float*)out_uncond->data;
             }
@@ -942,6 +970,7 @@ class StableDiffusionGGML {
                                          -1,
                                          controls,
                                          control_strength,
+                                         work_ctx,
                                          &out_skip,
                                          NULL,
                                          skip_layers);
@@ -1130,7 +1159,12 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
                      bool keep_clip_on_cpu,
                      bool keep_control_net_cpu,
                      bool keep_vae_on_cpu,
-                     bool diffusion_flash_attn) {
+                     bool diffusion_flash_attn,
+                     // DeepCache parameters
+                     int dc_cache_interval,
+                     int dc_cache_depth,
+                     int dc_start_steps,
+                     int dc_end_steps) {
     sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
     if (sd_ctx == NULL) {
         return NULL;
@@ -1151,7 +1185,11 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
                                            vae_decode_only,
                                            free_params_immediately,
                                            lora_model_dir,
-                                           rng_type);
+                                           rng_type,
+                                           dc_cache_interval,
+                                           dc_cache_depth,
+                                           dc_start_steps,
+                                           dc_end_steps);
     if (sd_ctx->sd == NULL) {
         return NULL;
     }
@@ -1439,6 +1477,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
         LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed);
         sd_ctx->sd->rng->manual_seed(cur_seed);
+
+        // Reset DeepCache state for the UNet model for this new image/seed
+        auto unet_model = std::dynamic_pointer_cast<UNetModel>(sd_ctx->sd->diffusion_model);
+        if (unet_model) {
+            unet_model->unet.unet.reset_deepcache_state();
+        }
+
         struct ggml_tensor* x_t = init_latent;
         struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
         ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
@@ -1561,6 +1608,14 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
     if (sd_ctx->sd->stacked_id) {
         params.mem_size += static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
     }
+
+    auto unet_model = std::dynamic_pointer_cast<UNetModel>(sd_ctx->sd->diffusion_model);
+    if (unet_model && unet_model->unet.unet.dc_cache_interval_ > 0) {
+        LOG_DEBUG("Allocating extra memory for DeepCache tensor");
+        // Conservative upper bound: the cached feature map has at most 1280 channels
+        // and is stored as F32 in the work context (~20 MB for 512x512).
+        size_t cache_tensor_size = 1280 * (height / 8) * (width / 8) * sizeof(float);
+        params.mem_size += cache_tensor_size;
+    }
+
     params.mem_size += width * height * 3 * sizeof(float);
     params.mem_size *= batch_count;
     params.mem_buffer = NULL;
@@ -1673,6 +1728,14 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
     if (sd_ctx->sd->stacked_id) {
         params.mem_size += static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
     }
+
+    auto unet_model = std::dynamic_pointer_cast<UNetModel>(sd_ctx->sd->diffusion_model);
+    if (unet_model && unet_model->unet.unet.dc_cache_interval_ > 0) {
+        LOG_DEBUG("Allocating extra memory for DeepCache tensor");
+        // Same conservative F32 upper bound as in txt2img.
+        size_t cache_tensor_size = 1280 * (height / 8) * (width / 8) * sizeof(float);
+        params.mem_size += cache_tensor_size;
+    }
+
     params.mem_size += width * height * 3 * sizeof(float) * 3;
     params.mem_size *= batch_count;
     params.mem_buffer = NULL;
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 52dcc848a..054715627 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -150,7 +150,12 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
                             bool keep_clip_on_cpu,
                             bool keep_control_net_cpu,
                             bool keep_vae_on_cpu,
-                            bool diffusion_flash_attn);
+                            bool diffusion_flash_attn,
+                            // DeepCache parameters. This is a C header, so no default
+                            // arguments: pass 0, 3, 0, 9999 for the defaults (interval 0 disables DeepCache).
+                            int dc_cache_interval,
+                            int dc_cache_depth,
+                            int dc_start_steps,
+                            int dc_end_steps);

 SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);

diff --git a/unet.hpp b/unet.hpp
index 31b7fe986..d48ef4e08 100644
--- a/unet.hpp
+++ b/unet.hpp
@@ -184,8 +184,39 @@ class UnetModelBlock : public GGMLBlock {
     int model_channels = 320;
     int adm_in_channels = 2816;  // only for VERSION_SDXL/SVD
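+    // DeepCache (https://github.com/horseee/DeepCache): skip the deep UNet blocks on
+    // most sampling steps and reuse a cached deep feature map instead.
+    //   dc_cache_interval_: recompute the deep blocks every N-th step (0 disables DeepCache)
+    //   dc_cache_depth_:    how many shallow input modules always run; deeper ones are skipped on cached steps
+    //   dc_start_steps_/dc_end_steps_: sampling-step window in which caching is active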
-    UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
-        : version(version) {
+    // DeepCache parameters
+    int dc_cache_interval_ = 0;
+    int dc_cache_depth_ = 0;
+    int dc_start_steps_ = 0;
+    int dc_end_steps_ = 0;
+    bool dc_input_cache_ = true;
+    bool dc_middle_cache_ = true;
+    bool dc_output_cache_ = true;
+
+    // DeepCache runtime state
+    float dc_current_time_ = -1.f;
+    int dc_current_step_ = -1;
+    int dc_model_step_ = -1;  // must match reset_deepcache_state()
+    ggml_tensor* dc_cache_h_ = nullptr;
+    ggml_tensor* dc_h_to_cache_this_step_ = nullptr;
+
+    UnetModelBlock(SDVersion version = VERSION_SD1,
+                   std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
+                   bool flash_attn = false,
+                   int dc_cache_interval = 0,
+                   int dc_cache_depth = 3,
+                   int dc_start_steps = 0,
+                   int dc_end_steps = 9999)
+        : version(version),
+          dc_cache_interval_(dc_cache_interval),
+          dc_cache_depth_(dc_cache_depth),
+          dc_start_steps_(dc_start_steps),
+          dc_end_steps_(dc_end_steps) {
+        LOG_DEBUG("DeepCache: UnetModelBlock constructor. interval: %d, depth: %d, start: %d, end: %d",
+                  dc_cache_interval_, dc_cache_depth_, dc_start_steps_, dc_end_steps_);
         if (sd_version_is_sd2(version)) {
             context_dim = 1024;
             num_head_channels = 64;
@@ -344,6 +375,17 @@ class UnetModelBlock : public GGMLBlock {
         blocks["out.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(model_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
     }

+    void reset_deepcache_state() {
+        dc_current_time_ = -1.f;
+        dc_current_step_ = -1;
+        dc_model_step_ = -1;
+        // dc_cache_h_ is expected to live in a work_ctx that will be freed,
+        // so nulling the pointer is sufficient. If it lived in a persistent
+        // context, it would need to be explicitly freed here.
+        dc_cache_h_ = nullptr;
+    }
+
     struct ggml_tensor* resblock_forward(std::string name,
                                          struct ggml_context* ctx,
                                          struct ggml_tensor* x,
@@ -376,7 +418,8 @@ class UnetModelBlock : public GGMLBlock {
         }
     }

-    struct ggml_tensor* forward(struct ggml_context* ctx,
+    struct ggml_tensor* forward(struct ggml_context* ctx,  // the per-step compute_ctx
+                                struct ggml_context* persistent_work_ctx,  // holds dc_cache_h_ across steps
                                 struct ggml_tensor* x,
                                 struct ggml_tensor* timesteps,
                                 struct ggml_tensor* context,
@@ -385,6 +428,41 @@ class UnetModelBlock : public GGMLBlock {
                                 int num_video_frames = -1,
                                 std::vector<struct ggml_tensor*> controls = {},
                                 float control_strength = 0.f) {
+        dc_h_to_cache_this_step_ = nullptr;
+
+        // DeepCache state update
+        bool dc_enabled = dc_cache_interval_ > 0;
+        LOG_DEBUG("DeepCache: UNet forward. dc_enabled: %d, interval: %d, depth: %d, start: %d, end: %d",
+                  dc_enabled, dc_cache_interval_, dc_cache_depth_, dc_start_steps_, dc_end_steps_);
+        bool dc_cache_apply = false;
+        bool dc_step_cache_interval_is_zero = false;
+
+        if (dc_enabled) {
+            float t_val = ggml_tensor_get_f32(timesteps, 0);
+
+            // Update step counters only on a new sampling step (i.e., a new time value),
+            // so the cond/uncond passes of the same step share one counter
+            if (dc_current_time_ != t_val) {
+                dc_current_time_ = t_val;
+                dc_model_step_++;  // sampling step counter (0, 1, 2, ...)

+                // Update the cache-application counter
+                if (dc_start_steps_ <= dc_model_step_ && dc_model_step_ <= dc_end_steps_) {
+                    dc_current_step_++;
+                } else {
+                    dc_current_step_ = -1;
+                }
+            }
+
+            dc_cache_apply = (dc_start_steps_ <= dc_model_step_ && dc_model_step_ <= dc_end_steps_);
+
+            if (dc_current_step_ != -1) {  // only if caching is active for this step
+                dc_step_cache_interval_is_zero = (dc_current_step_ % dc_cache_interval_ == 0);
+            }
+        }
+        LOG_DEBUG("DeepCache: State update done. model_step: %d, current_step: %d, cache_apply: %d, step_cache_interval_is_zero: %d, current_time: %.2f",
+                  dc_model_step_, dc_current_step_, dc_cache_apply, dc_step_cache_interval_is_zero, dc_current_time_);
+
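+        // The rest of forward() implements three coordinated skips on cached steps
+        // (dc_cache_apply && !dc_step_cache_interval_is_zero):
+        //   1. input blocks:  stop after conceptual module dc_cache_depth_
+        //   2. middle block:  skipped entirely
+        //   3. output blocks: blocks below the cache boundary are skipped and 'h'
+        //      is restored from dc_cache_h_ at the boundary instead
+        // On refresh steps (counter % interval == 0) everything runs, and 'h' at the
+        // boundary is written back into dc_cache_h_.
+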
         // x: [N, in_channels, h, w] or [N, in_channels/2, h, w]
         // timesteps: [N,]
         // context: [N, max_position, hidden_size] or [1, max_position, hidden_size]. for example, [N, 77, 768]
@@ -437,90 +515,194 @@ class UnetModelBlock : public GGMLBlock {

         // input_blocks
         std::vector<struct ggml_tensor*> hs;
+        int conceptual_module_id = -1;  // 0-indexed conceptual module ID
+        bool input_skipped_deeper = false;
+        auto h = x;  // initialize h with x before the first block operation
+        size_t len_mults = channel_mult.size();
+        int internal_block_counter = 0;  // for naming blocks like "input_blocks.1.0", "input_blocks.2.0"
+        int ds = 1;

-        // input block 0
-        auto h = input_blocks_0_0->forward(ctx, x);
+        // input block 0 (initial conv)
+        conceptual_module_id++;  // conceptual_module_id is now 0
+        h = input_blocks_0_0->forward(ctx, h);
         ggml_set_name(h, "bench-start");
         hs.push_back(h);

-        // input block 1-11
-        size_t len_mults = channel_mult.size();
-        int input_block_idx = 0;
-        int ds = 1;
-        for (int i = 0; i < len_mults; i++) {
+        if (dc_enabled && conceptual_module_id == dc_cache_depth_ && dc_cache_apply && dc_input_cache_ && !dc_step_cache_interval_is_zero) {
+            LOG_DEBUG("DeepCache: INPUT SKIP triggered at conceptual_module_id %d (cache_depth %d). Skipping deeper input blocks.", conceptual_module_id, dc_cache_depth_);
+            input_skipped_deeper = true;
+            goto after_input_blocks_processing_label;
+        }
+
+        // input_blocks main loop (len_mults, internal_block_counter and ds are declared
+        // above so that the goto does not jump over initializations)
+        for (int i = 0; i < len_mults; i++) {  // levels
             int mult = channel_mult[i];
-            for (int j = 0; j < num_res_blocks; j++) {
-                input_block_idx += 1;
-                std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
-                h = resblock_forward(name, ctx, h, emb, num_video_frames);  // [N, mult*model_channels, h, w]
+            for (int j = 0; j < num_res_blocks; j++) {  // resblocks within a level
+                internal_block_counter++;
+                conceptual_module_id++;
+
+                std::string res_name = "input_blocks." + std::to_string(internal_block_counter) + ".0";
+                h = resblock_forward(res_name, ctx, h, emb, num_video_frames);
                 if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
-                    std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
-                    h = attention_layer_forward(name, ctx, h, context, num_video_frames);  // [N, mult*model_channels, h, w]
+                    std::string attn_name = "input_blocks." + std::to_string(internal_block_counter) + ".1";
+                    h = attention_layer_forward(attn_name, ctx, h, context, num_video_frames);
                 }
                 hs.push_back(h);
+
+                if (dc_enabled && conceptual_module_id == dc_cache_depth_ && dc_cache_apply && dc_input_cache_ && !dc_step_cache_interval_is_zero) {
+                    LOG_DEBUG("DeepCache: INPUT SKIP triggered at conceptual_module_id %d (cache_depth %d). Skipping deeper input blocks.", conceptual_module_id, dc_cache_depth_);
+                    input_skipped_deeper = true;
+                    goto after_input_blocks_processing_label;
+                }
             }
-            if (i != len_mults - 1) {
+            if (i != len_mults - 1) {  // if not the last level, there is a downsample block
+                internal_block_counter++;
+                conceptual_module_id++;  // the downsample is one conceptual module
                 ds *= 2;
-                input_block_idx += 1;
-
-                std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
-                auto block = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
-
-                h = block->forward(ctx, h);  // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))]
+                std::string down_name = "input_blocks." + std::to_string(internal_block_counter) + ".0";
+                auto block = std::dynamic_pointer_cast<DownSampleBlock>(blocks[down_name]);
+                h = block->forward(ctx, h);
                 hs.push_back(h);
+
+                if (dc_enabled && conceptual_module_id == dc_cache_depth_ && dc_cache_apply && dc_input_cache_ && !dc_step_cache_interval_is_zero) {
+                    LOG_DEBUG("DeepCache: INPUT SKIP triggered at conceptual_module_id %d (cache_depth %d). Skipping deeper input blocks.", conceptual_module_id, dc_cache_depth_);
+                    input_skipped_deeper = true;
+                    goto after_input_blocks_processing_label;
+                }
             }
         }
-        // [N, 4*model_channels, h/8, w/8]

-        // middle_block
-        h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames);             // [N, 4*model_channels, h/8, w/8]
-        h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
-        h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);             // [N, 4*model_channels, h/8, w/8]
+    after_input_blocks_processing_label:;
+        // Middle block - this logic applies regardless of input_skipped_deeper;
+        // only the value of 'h' entering this section matters.
+        if (dc_enabled && dc_cache_apply && dc_middle_cache_ && !dc_step_cache_interval_is_zero) {
+            LOG_DEBUG("DeepCache: MIDDLE SKIP triggered. Middle block computation skipped.");
+            // 'h' remains unchanged from the input stage.
+        } else {
+            LOG_DEBUG("DeepCache: MIDDLE COMPUTE. Middle block will be computed.");
+            h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames);
+            h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames);
+            h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);
+        }
+        // Now h is the tensor that is passed to the output blocks or cached.

         if (controls.size() > 0) {
             auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
             h = ggml_add(ctx, h, cs);  // middle control
         }
         int control_offset = controls.size() - 2;

         // output_blocks
-        int output_block_idx = 0;
-        for (int i = (int)len_mults - 1; i >= 0; i--) {
-            for (int j = 0; j < num_res_blocks + 1; j++) {
+        // For SD 1.x there are 12 output blocks, so conceptual_output_block_id runs 0..11.
+        int total_output_blocks = (int)len_mults * (num_res_blocks + 1);
+
+        // Recompute the entry value of ds, since the input loop may have been exited early.
+        int temp_ds = 1;
+        for (int k = 0; k < (int)len_mults - 1; ++k) {
+            temp_ds *= 2;
+        }
+
+        int conceptual_output_block_id = -1;
+        int internal_output_block_idx_for_map = -1;
+        ds = temp_ds;
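+        // Output-cache boundary: with T = total_output_blocks and d = dc_cache_depth_,
+        // conceptual output blocks 0 .. T - d - 2 are skipped on cached steps, and block
+        // T - d - 1 is where 'h' is cached (refresh step) or restored (reuse step).
+        // E.g. for SD 1.x with depth 3: T = 12, blocks 0..7 skipped, boundary at 8.
+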
+        for (int i = (int)len_mults - 1; i >= 0; i--) {
+            for (int j = 0; j < num_res_blocks + 1; j++) {
+                conceptual_output_block_id++;
+                internal_output_block_idx_for_map++;
+
+                // Determine if this block's processing is skipped by the output caching rule
+                bool skip_due_to_output_cache_rule = false;
+                if (dc_enabled && dc_cache_apply && dc_output_cache_ && !dc_step_cache_interval_is_zero) {
+                    if (conceptual_output_block_id < (total_output_blocks - dc_cache_depth_ - 1)) {
+                        LOG_DEBUG("DeepCache: OUTPUT SKIP evaluated for conceptual_output_block_id %d (cache boundary %d). Block processing will be skipped.",
+                                  conceptual_output_block_id, (total_output_blocks - dc_cache_depth_ - 1));
+                        skip_due_to_output_cache_rule = true;
+                    }
+                }
+
+                // Cache update/reuse for 'h' (the main feature path) at the boundary.
+                // This applies regardless of whether the current block's computation is
+                // skipped, since 'h' itself may be loaded from the cache here.
+                if (dc_enabled && dc_cache_apply && dc_output_cache_ &&
+                    conceptual_output_block_id == (total_output_blocks - dc_cache_depth_ - 1)) {
+                    if (dc_step_cache_interval_is_zero) {  // cache-save step
+                        LOG_DEBUG("DeepCache: Marking 'h' for OUTPUT CACHE UPDATE at conceptual_output_block_id %d.", conceptual_output_block_id);
+                        dc_h_to_cache_this_step_ = h;
+                    } else {  // cache-reuse step
+                        if (dc_cache_h_ != nullptr) {
+                            LOG_DEBUG("DeepCache: OUTPUT CACHE REUSE for 'h' at conceptual_output_block_id %d.", conceptual_output_block_id);
+                            // ggml_dup adds a graph op that copies the cached data;
+                            // ggml_dup_tensor would only allocate an uninitialized
+                            // tensor of the same shape.
+                            h = ggml_dup(ctx, dc_cache_h_);
+                        } else {
+                            LOG_WARN("DeepCache: OUTPUT CACHE REUSE MISS for 'h' at conceptual_output_block_id %d. dc_cache_h_ is null! This is unexpected on a reuse step.", conceptual_output_block_id);
+                        }
+                    }
+                }
+
+                // If this block is skipped by the output cache rule, just account for ds and continue
+                if (skip_due_to_output_cache_rule) {
+                    // if an upsampler was part of this skipped conceptual block, adjust ds
+                    bool was_upsampler_block = (i > 0 && j == num_res_blocks);
+                    if (was_upsampler_block) {
+                        ds /= 2;
+                    }
+                    continue;
+                }
+
+                // This block is processed, so it consumes a skip connection from the input stage.
+                if (hs.empty()) {
+                    // The input stage was truncated more severely than the output stage expects.
+                    LOG_ERROR("DeepCache/UNet error: hs stack is empty for processed output block (conceptual_id %d). cache_depth (%d) is likely inconsistent with the UNet structure.",
+                              conceptual_output_block_id, dc_cache_depth_);
+                    return nullptr;
+                }
                 auto h_skip = hs.back();
-                hs.pop_back();
+                hs.pop_back();  // pop only when actually processing the block

-                if (controls.size() > 0) {
+                // Apply controlnet to h_skip if applicable
+                if (controls.size() > 0 && control_offset >= 0 && (size_t)control_offset < controls.size()) {
                     auto cs = ggml_scale_inplace(ctx, controls[control_offset], control_strength);
                     h_skip = ggml_add(ctx, h_skip, cs);  // control net condition
                     control_offset--;
                 }

+                // Actual block computation
                 h = ggml_concat(ctx, h, h_skip, 2);
-
-                std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0";
-
-                h = resblock_forward(name, ctx, h, emb, num_video_frames);
+                std::string name_res = "output_blocks." + std::to_string(internal_output_block_idx_for_map) + ".0";
+                h = resblock_forward(name_res, ctx, h, emb, num_video_frames);

                 int up_sample_idx = 1;
                 if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
-                    std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
-
-                    h = attention_layer_forward(name, ctx, h, context, num_video_frames);
-
+                    std::string name_attn = "output_blocks." + std::to_string(internal_output_block_idx_for_map) + ".1";
+                    h = attention_layer_forward(name_attn, ctx, h, context, num_video_frames);
                     up_sample_idx++;
                 }

-                if (i > 0 && j == num_res_blocks) {
-                    std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
-                    auto block = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
-
-                    h = block->forward(ctx, h);
-
-                    ds /= 2;
+                bool is_upsampler_block = (i > 0 && j == num_res_blocks);
+                if (is_upsampler_block) {
+                    std::string name_up = "output_blocks." + std::to_string(internal_output_block_idx_for_map) + "." + std::to_string(up_sample_idx);
+                    auto block_up = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name_up]);
+                    if (!block_up) {
+                        LOG_ERROR("Failed to get UpSampleBlock: %s. ds=%d, i=%d, j=%d, up_sample_idx=%d, map_id=%d",
+                                  name_up.c_str(), ds, i, j, up_sample_idx, internal_output_block_idx_for_map);
+                        return nullptr;
+                    }
+                    h = block_up->forward(ctx, h);
+                    ds /= 2;
                 }
-
-                output_block_idx += 1;
             }
         }

@@ -528,6 +710,27 @@ class UnetModelBlock : public GGMLBlock {
         h = out_0->forward(ctx, h);
         h = ggml_silu_inplace(ctx, h);
         h = out_2->forward(ctx, h);
+
+        if (dc_h_to_cache_this_step_ != nullptr) {
+            LOG_DEBUG("DeepCache: Adding cache-save operation to the graph.");
+            if (dc_cache_h_ == nullptr) {
+                LOG_DEBUG("DeepCache: Allocating persistent cache tensor.");
+                dc_cache_h_ = ggml_new_tensor_4d(persistent_work_ctx,
+                                                 dc_h_to_cache_this_step_->type,
+                                                 dc_h_to_cache_this_step_->ne[0],
+                                                 dc_h_to_cache_this_step_->ne[1],
+                                                 dc_h_to_cache_this_step_->ne[2],
+                                                 dc_h_to_cache_this_step_->ne[3]);
+                if (dc_cache_h_ == nullptr) {
+                    LOG_ERROR("DeepCache: Failed to allocate cache tensor in persistent context. Increase work_ctx size.");
+                    return nullptr;
+                }
+            }
+            // Copy the boundary features into the persistent cache. The sum scaled by
+            // 0.0f contributes nothing to h; it only keeps the copy op reachable from
+            // the graph output so it is not pruned.
+            auto cpy_op = ggml_cpy(ctx, dc_h_to_cache_this_step_, dc_cache_h_);
+            auto dummy_sum = ggml_sum(ctx, cpy_op);
+            h = ggml_add(ctx, h, ggml_scale(ctx, dummy_sum, 0.0f));
+        }
+
         ggml_set_name(h, "bench-end");
         return h;  // [N, out_channels, h, w]
     }
@@ -540,8 +743,14 @@ struct UNetModelRunner : public GGMLRunner {
                     std::map<std::string, enum ggml_type>& tensor_types,
                     const std::string prefix,
                     SDVersion version = VERSION_SD1,
-                    bool flash_attn = false)
-        : GGMLRunner(backend), unet(version, tensor_types, flash_attn) {
+                    bool flash_attn = false,
+                    // DeepCache parameters to pass to UnetModelBlock
+                    int dc_cache_interval = 0,
+                    int dc_cache_depth = 3,
+                    int dc_start_steps = 0,
+                    int dc_end_steps = 9999)
+        : GGMLRunner(backend), unet(version, tensor_types, flash_attn,
+                                    dc_cache_interval, dc_cache_depth, dc_start_steps, dc_end_steps) {
         unet.init(params_ctx, tensor_types, prefix);
     }

@@ -553,7 +762,8 @@ struct UNetModelRunner : public GGMLRunner {
         unet.get_param_tensors(tensors, prefix);
     }

-    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
+    struct ggml_cgraph* build_graph(struct ggml_context* persistent_work_ctx,
+                                    struct ggml_tensor* x,
                                     struct ggml_tensor* timesteps,
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* c_concat = NULL,
@@ -578,6 +788,7 @@ struct UNetModelRunner : public GGMLRunner {
         }

         struct ggml_tensor* out = unet.forward(compute_ctx,
+                                               persistent_work_ctx,
                                                x,
                                                timesteps,
                                                context,
@@ -598,6 +809,7 @@ struct UNetModelRunner : public GGMLRunner {
                  struct ggml_tensor* context,
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
+                 struct ggml_context* persistent_work_ctx,
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
@@ -609,7 +821,7 @@ struct UNetModelRunner : public GGMLRunner {
         // c_concat: [N, in_channels, h, w] or [1, in_channels, h, w]
         // y: [N, adm_in_channels] or [1, adm_in_channels]
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength);
+            return build_graph(persistent_work_ctx, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength);
         };

         GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -648,7 +860,8 @@ struct UNetModelRunner : public GGMLRunner {
         struct ggml_tensor* out = NULL;

         int t0 = ggml_time_ms();
-        compute(8, x, timesteps, context, NULL, y, num_video_frames, {}, 0.f, &out, work_ctx);
+        compute(8, x, timesteps, context, NULL, y, work_ctx, num_video_frames, {}, 0.f, &out, work_ctx);
         int t1 = ggml_time_ms();
         print_ggml_tensor(out);
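
Note on the caching cadence implemented by the counters in unet.hpp: step 0 and every dc_cache_interval-th step inside [dc_start_steps, dc_end_steps] run the full UNet and refresh the cache, while the steps in between run only the shallow blocks and reuse it. A minimal standalone sketch, not part of the patch (the variable names mirror the dc_* members), that prints which sampling steps recompute versus reuse:

#include <cstdio>

int main() {
    const int interval = 3, start_steps = 0, end_steps = 9999;  // --deepcache "3,3,0,9999"
    int current_step = -1;  // mirrors dc_current_step_
    for (int model_step = 0; model_step < 12; model_step++) {  // mirrors dc_model_step_
        bool in_window = (start_steps <= model_step && model_step <= end_steps);
        current_step = in_window ? current_step + 1 : -1;
        // Outside the window the full model always runs and no caching happens.
        bool full_pass = !in_window || (current_step % interval == 0);
        printf("step %2d: %s\n", model_step,
               full_pass ? "full UNet (cache refreshed)" : "shallow UNet (cache reused)");
    }
    return 0;
}

With interval 3 this prints a full pass on steps 0, 3, 6, 9 and cache reuse on the rest, which is where the speedup comes from.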