Skip to content

non-square VAE tiling #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 48 additions & 48 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,38 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {

typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;

// Computes, for a single dimension, how many tiles to use and the overlap
// factor that makes those tiles span the dimension.
//
// Outputs (written through the reference parameters):
//   num_tiles_dim           - number of tiles along this dimension
//   tile_overlap_factor_dim - overlap, as a fraction of tile_size, between
//                             neighbouring tiles for that tile count
// Inputs:
//   small_dim           - extent of the dimension being tiled
//   tile_size           - requested tile extent along this dimension
//   tile_overlap_factor - requested (target) overlap fraction
//
// Steps, in order:
//   1. Derive an integer overlap and stride from the requested factor and
//      count how many strides fit into small_dim.
//   2. Possibly add one more tile (see the condition in the body).
//   3. Recompute the overlap factor from the tile count as
//      (tile_size * n - small_dim) / (tile_size * (n - 1)), i.e. the fraction
//      at which n tiles of tile_size cover small_dim exactly.
//   4. If the count is 2 or less, override both outputs: a single tile with
//      overlap 0 when small_dim <= tile_size, otherwise exactly two tiles
//      with overlap (2 * tile_size - small_dim) / tile_size.
//
// NOTE: step 3 divides by (n - 1) in floating point before step 4 runs; for
// n == 1 that intermediate value is discarded because step 4 overwrites it.
// NOTE: the body divides by non_tile_overlap and takes a remainder by
// small_dim, so both must be non-zero.
__STATIC_INLINE__ void
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {

// requested overlap, truncated to a whole number of elements
int tile_overlap = (tile_size * tile_overlap_factor);
// stride between the starts of two consecutive tiles
int non_tile_overlap = tile_size - tile_overlap;

// integer division: number of whole strides that fit once the overlap is removed
num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
// extent covered by one additional tile, reduced modulo small_dim
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;

if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
// if tiles don't fit perfectly using the desired overlap
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
num_tiles_dim++;
}
Comment on lines +619 to +623
Copy link
Owner

@stduhpf stduhpf Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wbruna I think it's actually because of this.
If I remember correctly I added this to make sure the overlap was preferably bigger rather than smaller than the target (because less overlap tend to cause more noticable transitions).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If your fix with rounding doesn't work, removing these would work, though I think it's preferable to keep it.

Copy link
Owner

@stduhpf stduhpf Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to be working fine so far, no matter the overlap.


tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
if (num_tiles_dim <= 2) {
if (small_dim <= tile_size) {
num_tiles_dim = 1;
tile_overlap_factor_dim = 0;
} else {
num_tiles_dim = 2;
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
}
}
}

// Tiling
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
const int p_tile_size_x, const int p_tile_size_y,
const float tile_overlap_factor, on_tile_process on_processing) {

output = ggml_set_f32(output, 0);

int input_width = (int)input->ne[0];
Expand All @@ -629,62 +659,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
small_height = input_height;
}

int tile_overlap = (tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;

int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;

if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
// if tiles don't fit perfectly using the desired overlap
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
num_tiles_x++;
}

float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
if (num_tiles_x <= 2) {
if (small_width <= tile_size) {
num_tiles_x = 1;
tile_overlap_factor_x = 0;
} else {
num_tiles_x = 2;
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
}
}

int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
int num_tiles_x;
float tile_overlap_factor_x;
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);

if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
// if tiles don't fit perfectly using the desired overlap
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
num_tiles_y++;
}

float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
if (num_tiles_y <= 2) {
if (small_height <= tile_size) {
num_tiles_y = 1;
tile_overlap_factor_y = 0;
} else {
num_tiles_y = 2;
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
}
}
int num_tiles_y;
float tile_overlap_factor_y;
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);

LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);

GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2

int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
int non_tile_overlap_x = tile_size - tile_overlap_x;
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;

int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
int non_tile_overlap_y = tile_size - tile_overlap_y;
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;

int tile_size_x = tile_size < small_width ? tile_size : small_width;
int tile_size_y = tile_size < small_height ? tile_size : small_height;
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;

int input_tile_size_x = tile_size_x;
int input_tile_size_y = tile_size_y;
Expand Down Expand Up @@ -773,6 +768,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
ggml_free(tiles_ctx);
}

// Square-tile entry point kept for existing callers: forwards every argument
// to sd_tiling_non_square, using tile_size for both the x and the y tile
// extent.
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input,
                                 ggml_tensor* output,
                                 const int scale,
                                 const int tile_size,
                                 const float tile_overlap_factor,
                                 on_tile_process on_processing) {
    const int tile_extent_x = tile_size;
    const int tile_extent_y = tile_size;
    sd_tiling_non_square(input, output, scale,
                         tile_extent_x, tile_extent_y,
                         tile_overlap_factor, on_processing);
}

__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
struct ggml_tensor* a) {
const float eps = 1e-6f; // default eps parameter
Expand Down
88 changes: 81 additions & 7 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,23 +1427,91 @@ class StableDiffusionGGML {
x->ne[3]); // channels
int64_t t0 = ggml_time_ms();

int tile_size = 32;
// TODO: arg instead of env?
// TODO: args instead of env for tile size / overlap?

float tile_overlap = 0.5f;
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
if (SD_TILE_OVERLAP != nullptr) {
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
try {
tile_overlap = std::stof(sd_tile_overlap_str);
if (tile_overlap < 0.0) {
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
tile_overlap = 0.0;
}
else if (tile_overlap > 0.5) {
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
tile_overlap = 0.5;
}
} catch (const std::invalid_argument&) {
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
} catch (const std::out_of_range&) {
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
}
}

int tile_size_x = 32;
int tile_size_y = 32;
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
if (SD_TILE_SIZE != nullptr) {
// format is AxB, or just A (equivalent to AxA)
// A and B can be integers (tile size) or floating point
// floating point <= 1 means simple fraction of the latent dimension
// floating point > 1 means number of tiles across that dimension
// a single number gets applied to both
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
float factor = std::stof(factor_str);
if (factor > 1.0)
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
return factor;
};
const int latent_x = W / (decode ? 1 : 8);
const int latent_y = H / (decode ? 1 : 8);
const int min_tile_dimension = 4;
std::string sd_tile_size_str = SD_TILE_SIZE;
size_t x_pos = sd_tile_size_str.find('x');
try {
tile_size = std::stoi(sd_tile_size_str);
int tmp_x = tile_size_x, tmp_y = tile_size_y;
if (x_pos != std::string::npos) {
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
if (tile_x_str.find('.') != std::string::npos) {
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
}
else {
tmp_x = std::stoi(tile_x_str);
}
if (tile_y_str.find('.') != std::string::npos) {
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
}
else {
tmp_y = std::stoi(tile_y_str);
}
}
else {
if (sd_tile_size_str.find('.') != std::string::npos) {
float tile_factor = get_tile_factor(sd_tile_size_str);
tmp_x = std::round(latent_x * tile_factor);
tmp_y = std::round(latent_y * tile_factor);
}
else {
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
}
}
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
} catch (const std::invalid_argument&) {
LOG_WARN("Invalid");
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
} catch (const std::out_of_range&) {
LOG_WARN("OOR");
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
}
}

if(!decode){
// TODO: also use an arg for this one?
// to keep the compute buffer size consistent
tile_size*=1.30539;
tile_size_x*=1.30539;
tile_size_y*=1.30539;
}
if (!use_tiny_autoencoder) {
if (decode) {
Expand All @@ -1452,11 +1520,17 @@ class StableDiffusionGGML {
ggml_tensor_scale_input(x);
}
if (vae_tiling) {
if (SD_TILE_SIZE != nullptr) {
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
}
if (SD_TILE_OVERLAP != nullptr) {
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
}
// split latent into tiles of tile_size_x by tile_size_y and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else {
first_stage_model->compute(n_threads, x, decode, &result);
}
Expand Down