Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2846,7 +2846,8 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
{request->width / request->vae_scale_factor,
request->height / request->vae_scale_factor,
1,
1});
1},
sd::ops::InterpolateMode::NearestMax);

sd::Tensor<float> init_latent;
sd::Tensor<float> control_latent;
Expand Down Expand Up @@ -2991,8 +2992,12 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
latents.ref_latents = std::move(ref_latents);

if (sd_version_is_inpaint(sd_ctx->sd->version)) {
latents.denoise_mask = std::move(latent_mask);
latent_mask = sd::ops::max_pool_2d(latent_mask,
{3, 3},
{1, 1},
{1, 1});
}
latents.denoise_mask = std::move(latent_mask);

return latents;
}
Expand Down
197 changes: 185 additions & 12 deletions src/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,9 @@ namespace sd {
namespace ops {
// Sampling strategies for ops::interpolate(). All modes are "nearest-like":
// when upsampling (or when no axis shrinks) every mode behaves identically,
// copying the single nearest source element. The Max/Min/Avg variants only
// differ when downsampling, where each output element reduces over its whole
// source window instead of picking one sample (see the reduction logic in
// interpolate()).
enum class InterpolateMode {
    // Copy the nearest source element; never reduces over a window.
    Nearest,
    // Downsampling takes the maximum of the source window.
    NearestMax,
    // Downsampling takes the minimum of the source window.
    NearestMin,
    // Downsampling averages the source window.
    NearestAvg,
};

inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) {
Expand Down Expand Up @@ -1012,12 +1015,16 @@ namespace sd {
std::vector<int64_t> output_shape,
InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) {
if (mode != InterpolateMode::Nearest) {
tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" +
const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest ||
mode == InterpolateMode::NearestMax ||
mode == InterpolateMode::NearestMin ||
mode == InterpolateMode::NearestAvg);
if (!is_nearest_like_mode) {
tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" +
std::to_string(static_cast<int>(mode)));
}
if (align_corners) {
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" +
tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" +
tensor_shape_to_string(input.shape()) + ", output_shape=" +
tensor_shape_to_string(output_shape));
}
Expand All @@ -1044,14 +1051,102 @@ namespace sd {
}
}

bool has_downsampling = false;
for (int64_t i = 0; i < input.dim(); ++i) {
if (input.shape()[i] > output_shape[i]) {
has_downsampling = true;
break;
}
}

Tensor<T> output(std::move(output_shape));
for (int64_t flat = 0; flat < output.numel(); ++flat) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat, output.shape());
std::vector<int64_t> input_coord(static_cast<size_t>(input.dim()), 0);
for (size_t i = 0; i < static_cast<size_t>(input.dim()); ++i) {
input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i];
if (mode == InterpolateMode::Nearest || !has_downsampling) {
for (int64_t flat = 0; flat < output.numel(); ++flat) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat, output.shape());
std::vector<int64_t> input_coord(static_cast<size_t>(input.dim()), 0);
for (size_t i = 0; i < static_cast<size_t>(input.dim()); ++i) {
input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i];
}
output[flat] = input.index(input_coord);
}
output[flat] = input.index(input_coord);

return output;
}

auto init_reduction = [&]() -> T {
switch (mode) {
case InterpolateMode::NearestMax:
return std::numeric_limits<T>::lowest();
case InterpolateMode::NearestMin:
return std::numeric_limits<T>::max();
case InterpolateMode::NearestAvg:
return T(0);
case InterpolateMode::Nearest:
return T(0);
}

tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" +
std::to_string(static_cast<int>(mode)));
};

auto reduce_value = [&](T& acc, const T& sample) {
switch (mode) {
case InterpolateMode::NearestMax:
acc = std::max(acc, sample);
break;
case InterpolateMode::NearestMin:
acc = std::min(acc, sample);
break;
case InterpolateMode::NearestAvg:
acc += sample;
break;
case InterpolateMode::Nearest:
break;
}
};

// Reduction modes only differ from nearest mode when downsampling.
for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat_out, output.shape());

std::vector<int64_t> input_start(output.dim(), 0);
std::vector<int64_t> input_end(output.dim(), 0);

for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
const int64_t input_dim = input.shape()[i];
const int64_t output_dim = output.shape()[i];

input_start[i] = std::max(int64_t(0), static_cast<int64_t>(output_coord[i] * input_dim / output_dim));
input_end[i] = std::min(input_dim, ((output_coord[i] + 1) * input_dim + output_dim - 1) / output_dim);
}

T value = init_reduction();
bool done_window = false;
std::vector<int64_t> current_in_coord = input_start;

while (!done_window) {
reduce_value(value, input.index(current_in_coord));

for (int d = static_cast<int>(output.dim()) - 1; d >= 0; --d) {
if (++current_in_coord[d] < input_end[d]) {
break;
}
current_in_coord[d] = input_start[d];
if (d == 0) {
done_window = true;
}
}
}

if (mode == InterpolateMode::NearestAvg) {
int64_t window_size = 1;
for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
window_size *= (input_end[i] - input_start[i]);
}
value /= static_cast<T>(window_size);
}

output[flat_out] = value;
}

return output;
Expand All @@ -1063,12 +1158,16 @@ namespace sd {
const std::optional<std::vector<double>>& scale_factor,
InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) {
if (mode != InterpolateMode::Nearest) {
tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" +
const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest ||
mode == InterpolateMode::NearestMax ||
mode == InterpolateMode::NearestMin ||
mode == InterpolateMode::NearestAvg);
if (!is_nearest_like_mode) {
tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" +
std::to_string(static_cast<int>(mode)));
}
if (align_corners) {
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" +
tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" +
tensor_shape_to_string(input.shape()));
}
if (size.has_value() == scale_factor.has_value()) {
Expand Down Expand Up @@ -1128,6 +1227,80 @@ namespace sd {
align_corners);
}

template <typename T>
inline Tensor<T> max_pool_2d(const Tensor<T>& input,
                             std::vector<int64_t> kernel_size,
                             std::vector<int64_t> stride,
                             std::vector<int64_t> padding) {
    // 2D max pooling over the first two axes (height, width) of `input`;
    // any trailing axes are carried through unchanged. Out-of-bounds
    // (padded) positions never contribute to the maximum, and a window that
    // covers no in-bounds element at all yields T(0).
    //
    // kernel_size / stride / padding are {height, width} pairs. Throws via
    // tensor_throw_invalid_argument on malformed arguments or when the
    // computed output would be empty.
    if (input.dim() < 2) {
        tensor_throw_invalid_argument("Tensor max_pool_2d requires input_dim >= 2: input_dim=" +
                                      std::to_string(input.dim()) + ", input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    if (kernel_size.size() != 2 || stride.size() != 2 || padding.size() != 2) {
        tensor_throw_invalid_argument("Tensor max_pool_2d requires kernel_size, stride, and padding to have length 2");
    }
    // Validate axis by axis so the first offending argument is the one reported.
    for (size_t axis = 0; axis < 2; ++axis) {
        if (kernel_size[axis] <= 0) {
            tensor_throw_invalid_argument("Tensor max_pool_2d kernel_size must be positive: kernel_size=" +
                                          tensor_shape_to_string(kernel_size));
        }
        if (stride[axis] <= 0) {
            tensor_throw_invalid_argument("Tensor max_pool_2d stride must be positive: stride=" +
                                          tensor_shape_to_string(stride));
        }
        if (padding[axis] < 0) {
            tensor_throw_invalid_argument("Tensor max_pool_2d padding must be non-negative: padding=" +
                                          tensor_shape_to_string(padding));
        }
    }

    const int64_t src_h = input.shape()[0];
    const int64_t src_w = input.shape()[1];

    // Standard pooling output size: floor((in + 2*pad - kernel) / stride) + 1.
    const int64_t dst_h = (src_h + 2 * padding[0] - kernel_size[0]) / stride[0] + 1;
    const int64_t dst_w = (src_w + 2 * padding[1] - kernel_size[1]) / stride[1] + 1;

    if (dst_h <= 0 || dst_w <= 0) {
        tensor_throw_invalid_argument("max_pool_2d results in invalid output dimensions: " +
                                      std::to_string(dst_h) + "x" + std::to_string(dst_w));
    }

    std::vector<int64_t> out_shape = input.shape();
    out_shape[0] = dst_h;
    out_shape[1] = dst_w;

    Tensor<T> result(std::move(out_shape));

    for (int64_t out_index = 0; out_index < result.numel(); ++out_index) {
        // `coord` starts as the output coordinate; trailing axes stay as-is,
        // only the two pooled axes are rewritten while scanning the window.
        std::vector<int64_t> coord = tensor_unravel_index(out_index, result.shape());
        const int64_t out_y = coord[0];
        const int64_t out_x = coord[1];

        T best = std::numeric_limits<T>::lowest();
        bool seen_any = false;

        for (int64_t ky = 0; ky < kernel_size[0]; ++ky) {
            const int64_t in_y = out_y * stride[0] + ky - padding[0];
            if (in_y < 0 || in_y >= src_h) {
                continue;  // row lies in the padding region
            }
            for (int64_t kx = 0; kx < kernel_size[1]; ++kx) {
                const int64_t in_x = out_x * stride[1] + kx - padding[1];
                if (in_x < 0 || in_x >= src_w) {
                    continue;  // column lies in the padding region
                }
                coord[0] = in_y;
                coord[1] = in_x;
                best = std::max(best, input.index(coord));
                seen_any = true;
            }
        }

        // A window entirely inside the padding contributes no samples.
        result[out_index] = seen_any ? best : T(0);
    }
    return result;
}

template <typename T>
inline Tensor<T> concat(const Tensor<T>& lhs, const Tensor<T>& rhs, size_t dim) {
if (lhs.dim() != rhs.dim()) {
Expand Down
Loading