// Review notes (annotation only; the patch text below is unchanged).
//
// This is a two-file patch whose line breaks have been collapsed; each hunk
// shows only a fragment of the function it touches.
//
// src/stable-diffusion.cpp, prepare_image_generation_latents:
//   * Hunk 1 appends sd::ops::InterpolateMode::NearestMax as an extra argument
//     to a call whose start lies outside the hunk (presumably the interpolate
//     call that produces latent_mask -- confirm against the full file).
//   * Hunk 2: when sd_version_is_inpaint(...) is true, latent_mask is now
//     passed through sd::ops::max_pool_2d with kernel {3, 3}, stride {1, 1}
//     and padding {1, 1}. The assignment
//     latents.denoise_mask = std::move(latent_mask) moves out of the `if`,
//     so it now runs for every version instead of only for inpaint versions.
//     NOTE(review): confirm that setting denoise_mask for non-inpaint
//     versions is intended.
//
// src/tensor.hpp, namespace sd:
//   * InterpolateMode gains NearestMax, NearestMin and NearestAvg.
//   * Both interpolate overloads shown now accept the three new modes in
//     their mode check; align_corners is still rejected with an exception.
//   * First overload: when mode is Nearest, or no input dimension is larger
//     than the matching output dimension, it keeps the original lookup
//     input_coord[i] = output_coord[i] * in / out. Otherwise every output
//     element is reduced (max, min, or sum divided by window size) over the
//     per-dimension window
//       [max(0, o * in / out), min(in, ((o + 1) * in + out - 1) / out)).
//     NOTE(review): done_window is only set inside the `for d` loop at
//     d == 0; if output.dim() were 0 that loop body would never run and the
//     `while` would not terminate -- confirm rank-0 tensors are rejected
//     earlier.
//     NOTE(review): init_reduction has no return after
//     tensor_throw_invalid_argument; presumably that helper never returns --
//     confirm.
//   * New max_pool_2d: throws unless input.dim() >= 2, kernel_size / stride /
//     padding each have length 2, kernel and stride entries are positive and
//     padding entries are non-negative. It pools over dimensions 0 and 1
//     (called in_height / in_width in the code); remaining dimensions are
//     copied through from the output coordinate. The pooled extent is
//     (in + 2 * padding - kernel) / stride + 1 and must be positive.
//     Positions that fall outside the input are skipped rather than read as
//     a padding value; an output element with no valid input is set to T(0).
diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index ae34530b0..683a07d53 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -2846,7 +2846,8 @@ static std::optional prepare_image_generation_latents(sd {request->width / request->vae_scale_factor, request->height / request->vae_scale_factor, 1, - 1}); + 1}, + sd::ops::InterpolateMode::NearestMax); sd::Tensor init_latent; sd::Tensor control_latent; @@ -2991,8 +2992,12 @@ static std::optional prepare_image_generation_latents(sd latents.ref_latents = std::move(ref_latents); if (sd_version_is_inpaint(sd_ctx->sd->version)) { - latents.denoise_mask = std::move(latent_mask); + latent_mask = sd::ops::max_pool_2d(latent_mask, + {3, 3}, + {1, 1}, + {1, 1}); } + latents.denoise_mask = std::move(latent_mask); return latents; } diff --git a/src/tensor.hpp b/src/tensor.hpp index 33a2bdeaa..33302b056 100644 --- a/src/tensor.hpp +++ b/src/tensor.hpp @@ -815,6 +815,9 @@ namespace sd { namespace ops { enum class InterpolateMode { Nearest, + NearestMax, + NearestMin, + NearestAvg, }; inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) { @@ -1012,12 +1015,16 @@ namespace sd { std::vector output_shape, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false) { - if (mode != InterpolateMode::Nearest) { - tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + + const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest || + mode == InterpolateMode::NearestMax || + mode == InterpolateMode::NearestMin || + mode == InterpolateMode::NearestAvg); + if (!is_nearest_like_mode) { + tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" + std::to_string(static_cast(mode))); } if (align_corners) { - tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + + tensor_throw_invalid_argument("align_corners is not supported for nearest-like 
interpolate: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } @@ -1044,14 +1051,102 @@ namespace sd { } } + bool has_downsampling = false; + for (int64_t i = 0; i < input.dim(); ++i) { + if (input.shape()[i] > output_shape[i]) { + has_downsampling = true; + break; + } + } + Tensor output(std::move(output_shape)); - for (int64_t flat = 0; flat < output.numel(); ++flat) { - std::vector output_coord = tensor_unravel_index(flat, output.shape()); - std::vector input_coord(static_cast(input.dim()), 0); - for (size_t i = 0; i < static_cast(input.dim()); ++i) { - input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; + if (mode == InterpolateMode::Nearest || !has_downsampling) { + for (int64_t flat = 0; flat < output.numel(); ++flat) { + std::vector output_coord = tensor_unravel_index(flat, output.shape()); + std::vector input_coord(static_cast(input.dim()), 0); + for (size_t i = 0; i < static_cast(input.dim()); ++i) { + input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; + } + output[flat] = input.index(input_coord); } - output[flat] = input.index(input_coord); + + return output; + } + + auto init_reduction = [&]() -> T { + switch (mode) { + case InterpolateMode::NearestMax: + return std::numeric_limits::lowest(); + case InterpolateMode::NearestMin: + return std::numeric_limits::max(); + case InterpolateMode::NearestAvg: + return T(0); + case InterpolateMode::Nearest: + return T(0); + } + + tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + + std::to_string(static_cast(mode))); + }; + + auto reduce_value = [&](T& acc, const T& sample) { + switch (mode) { + case InterpolateMode::NearestMax: + acc = std::max(acc, sample); + break; + case InterpolateMode::NearestMin: + acc = std::min(acc, sample); + break; + case InterpolateMode::NearestAvg: + acc += sample; + break; + case InterpolateMode::Nearest: + break; + } + }; + + // Reduction modes only differ 
from nearest mode when downsampling. + for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) { + std::vector output_coord = tensor_unravel_index(flat_out, output.shape()); + + std::vector input_start(output.dim(), 0); + std::vector input_end(output.dim(), 0); + + for (size_t i = 0; i < static_cast(output.dim()); ++i) { + const int64_t input_dim = input.shape()[i]; + const int64_t output_dim = output.shape()[i]; + + input_start[i] = std::max(int64_t(0), static_cast(output_coord[i] * input_dim / output_dim)); + input_end[i] = std::min(input_dim, ((output_coord[i] + 1) * input_dim + output_dim - 1) / output_dim); + } + + T value = init_reduction(); + bool done_window = false; + std::vector current_in_coord = input_start; + + while (!done_window) { + reduce_value(value, input.index(current_in_coord)); + + for (int d = static_cast(output.dim()) - 1; d >= 0; --d) { + if (++current_in_coord[d] < input_end[d]) { + break; + } + current_in_coord[d] = input_start[d]; + if (d == 0) { + done_window = true; + } + } + } + + if (mode == InterpolateMode::NearestAvg) { + int64_t window_size = 1; + for (size_t i = 0; i < static_cast(output.dim()); ++i) { + window_size *= (input_end[i] - input_start[i]); + } + value /= static_cast(window_size); + } + + output[flat_out] = value; } return output; @@ -1063,12 +1158,16 @@ namespace sd { const std::optional>& scale_factor, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false) { - if (mode != InterpolateMode::Nearest) { - tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + + const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest || + mode == InterpolateMode::NearestMax || + mode == InterpolateMode::NearestMin || + mode == InterpolateMode::NearestAvg); + if (!is_nearest_like_mode) { + tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" + std::to_string(static_cast(mode))); } if (align_corners) { - 
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + + tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" + tensor_shape_to_string(input.shape())); } if (size.has_value() == scale_factor.has_value()) { @@ -1128,6 +1227,80 @@ namespace sd { align_corners); } + template + inline Tensor max_pool_2d(const Tensor& input, + std::vector kernel_size, + std::vector stride, + std::vector padding) { + if (input.dim() < 2) { + tensor_throw_invalid_argument("Tensor max_pool_2d requires input_dim >= 2: input_dim=" + + std::to_string(input.dim()) + ", input_shape=" + + tensor_shape_to_string(input.shape())); + } + if (kernel_size.size() != 2 || stride.size() != 2 || padding.size() != 2) { + tensor_throw_invalid_argument("Tensor max_pool_2d requires kernel_size, stride, and padding to have length 2"); + } + for (size_t i = 0; i < 2; ++i) { + if (kernel_size[i] <= 0) { + tensor_throw_invalid_argument("Tensor max_pool_2d kernel_size must be positive: kernel_size=" + + tensor_shape_to_string(kernel_size)); + } + if (stride[i] <= 0) { + tensor_throw_invalid_argument("Tensor max_pool_2d stride must be positive: stride=" + + tensor_shape_to_string(stride)); + } + if (padding[i] < 0) { + tensor_throw_invalid_argument("Tensor max_pool_2d padding must be non-negative: padding=" + + tensor_shape_to_string(padding)); + } + } + + const int64_t in_height = input.shape()[0]; + const int64_t in_width = input.shape()[1]; + + const int64_t out_height = (in_height + 2 * padding[0] - kernel_size[0]) / stride[0] + 1; + const int64_t out_width = (in_width + 2 * padding[1] - kernel_size[1]) / stride[1] + 1; + + if (out_height <= 0 || out_width <= 0) { + tensor_throw_invalid_argument("max_pool_2d results in invalid output dimensions: " + + std::to_string(out_height) + "x" + std::to_string(out_width)); + } + + std::vector output_shape = input.shape(); + output_shape[0] = out_height; + output_shape[1] 
= out_width; + + Tensor output(std::move(output_shape)); + + for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) { + std::vector output_coord = tensor_unravel_index(flat_out, output.shape()); + std::vector input_coord = output_coord; + + const int64_t oh = output_coord[0]; + const int64_t ow = output_coord[1]; + + T max_val = std::numeric_limits::lowest(); + bool has_valid_input = false; + + for (int64_t kh = 0; kh < kernel_size[0]; ++kh) { + for (int64_t kw = 0; kw < kernel_size[1]; ++kw) { + const int64_t ih = oh * stride[0] + kh - padding[0]; + const int64_t iw = ow * stride[1] + kw - padding[1]; + + if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) { + input_coord[0] = ih; + input_coord[1] = iw; + max_val = std::max(max_val, input.index(input_coord)); + has_valid_input = true; + } + } + } + + output[flat_out] = has_valid_input ? max_val : T(0); + } + return output; + } + template inline Tensor concat(const Tensor& lhs, const Tensor& rhs, size_t dim) { if (lhs.dim() != rhs.dim()) {