diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 55bca154..d6395497 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -607,8 +607,38 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) { typedef std::function on_tile_process; +__STATIC_INLINE__ void +sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) { + + int tile_overlap = (tile_size * tile_overlap_factor); + int non_tile_overlap = tile_size - tile_overlap; + + num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap; + int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim; + + if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) { + // if tiles don't fit perfectly using the desired overlap + // and there is enough room to squeeze an extra tile without overlap becoming >0.5 + num_tiles_dim++; + } + + tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1)); + if (num_tiles_dim <= 2) { + if (small_dim <= tile_size) { + num_tiles_dim = 1; + tile_overlap_factor_dim = 0; + } else { + num_tiles_dim = 2; + tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size; + } + } +} + // Tiling -__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { +__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale, + const int p_tile_size_x, const int p_tile_size_y, + const float tile_overlap_factor, on_tile_process on_processing) { + output = ggml_set_f32(output, 0); int input_width = (int)input->ne[0]; @@ -629,62 +659,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const small_height = input_height; } - int tile_overlap = (tile_size * tile_overlap_factor); - int non_tile_overlap = tile_size - tile_overlap; - - int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap; - int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width; - - if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) { - // if tiles don't fit perfectly using the desired overlap - // and there is enough room to squeeze an extra tile without overlap becoming >0.5 - num_tiles_x++; - } - - float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1)); - if (num_tiles_x <= 2) { - if (small_width <= tile_size) { - num_tiles_x = 1; - tile_overlap_factor_x = 0; - } else { - num_tiles_x = 2; - tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size; - } - } - - int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap; - int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height; + int num_tiles_x; + float tile_overlap_factor_x; + sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor); - if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) { - // if tiles don't fit perfectly using the desired overlap - // and there is enough room to squeeze an extra tile without overlap becoming >0.5 - num_tiles_y++; - } - - float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1)); - if (num_tiles_y <= 2) { - if (small_height <= tile_size) { - num_tiles_y = 1; - tile_overlap_factor_y = 0; - } else { - num_tiles_y = 2; - tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size; - } - } + int num_tiles_y; + float tile_overlap_factor_y; + sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor); LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y); LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor); GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 - int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x); - int non_tile_overlap_x = tile_size - tile_overlap_x; + int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x); + int non_tile_overlap_x = p_tile_size_x - tile_overlap_x; - int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y); - int non_tile_overlap_y = tile_size - tile_overlap_y; + int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y); + int non_tile_overlap_y = p_tile_size_y - tile_overlap_y; - int tile_size_x = tile_size < small_width ? tile_size : small_width; - int tile_size_y = tile_size < small_height ? tile_size : small_height; + int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width; + int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height; int input_tile_size_x = tile_size_x; int input_tile_size_y = tile_size_y; @@ -773,6 +768,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const ggml_free(tiles_ctx); } +__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, + const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { + sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing); +} + __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, struct ggml_tensor* a) { const float eps = 1e-6f; // default eps parameter diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1d197148..29bda43c 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1427,23 +1427,91 @@ class StableDiffusionGGML { x->ne[3]); // channels int64_t t0 = ggml_time_ms(); - int tile_size = 32; - // TODO: arg instead of env? + // TODO: args instead of env for tile size / overlap? + + float tile_overlap = 0.5f; + const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP"); + if (SD_TILE_OVERLAP != nullptr) { + std::string sd_tile_overlap_str = SD_TILE_OVERLAP; + try { + tile_overlap = std::stof(sd_tile_overlap_str); + if (tile_overlap < 0.0) { + LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0"); + tile_overlap = 0.0; + } + else if (tile_overlap > 0.5) { + LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5"); + tile_overlap = 0.5; + } + } catch (const std::invalid_argument&) { + LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default"); + } catch (const std::out_of_range&) { + LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default"); + } + } + + int tile_size_x = 32; + int tile_size_y = 32; const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE"); if (SD_TILE_SIZE != nullptr) { + // format is AxB, or just A (equivalent to AxA) + // A and B can be integers (tile size) or floating point + // floating point <= 1 means simple fraction of the latent dimension + // floating point > 1 means number of tiles across that dimension + // a single number gets applied to both + auto get_tile_factor = [tile_overlap](const std::string& factor_str) { + float factor = std::stof(factor_str); + if (factor > 1.0) + factor = 1 / (factor - factor * tile_overlap + tile_overlap); + return factor; + }; + const int latent_x = W / (decode ? 1 : 8); + const int latent_y = H / (decode ? 1 : 8); + const int min_tile_dimension = 4; std::string sd_tile_size_str = SD_TILE_SIZE; + size_t x_pos = sd_tile_size_str.find('x'); try { - tile_size = std::stoi(sd_tile_size_str); + int tmp_x = tile_size_x, tmp_y = tile_size_y; + if (x_pos != std::string::npos) { + std::string tile_x_str = sd_tile_size_str.substr(0, x_pos); + std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1); + if (tile_x_str.find('.') != std::string::npos) { + tmp_x = std::round(latent_x * get_tile_factor(tile_x_str)); + } + else { + tmp_x = std::stoi(tile_x_str); + } + if (tile_y_str.find('.') != std::string::npos) { + tmp_y = std::round(latent_y * get_tile_factor(tile_y_str)); + } + else { + tmp_y = std::stoi(tile_y_str); + } + } + else { + if (sd_tile_size_str.find('.') != std::string::npos) { + float tile_factor = get_tile_factor(sd_tile_size_str); + tmp_x = std::round(latent_x * tile_factor); + tmp_y = std::round(latent_y * tile_factor); + } + else { + tmp_x = tmp_y = std::stoi(sd_tile_size_str); + } + } + tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension); + tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension); } catch (const std::invalid_argument&) { - LOG_WARN("Invalid"); + LOG_WARN("SD_TILE_SIZE is invalid, keeping the default"); } catch (const std::out_of_range&) { - LOG_WARN("OOR"); + LOG_WARN("SD_TILE_SIZE is out of range, keeping the default"); } } + if(!decode){ // TODO: also use and arg for this one? // to keep the compute buffer size consistent - tile_size*=1.30539; + tile_size_x*=1.30539; + tile_size_y*=1.30539; } if (!use_tiny_autoencoder) { if (decode) { @@ -1452,11 +1520,17 @@ class StableDiffusionGGML { ggml_tensor_scale_input(x); } if (vae_tiling) { + if (SD_TILE_SIZE != nullptr) { + LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y); + } + if (SD_TILE_OVERLAP != nullptr) { + LOG_INFO("VAE Tile overlap: %.2f", tile_overlap); + } // split latent in 32x32 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { first_stage_model->compute(n_threads, in, decode, &out); }; - sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling); + sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling); } else { first_stage_model->compute(n_threads, x, decode, &result); }