Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <chrono>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <mutex>
Expand Down Expand Up @@ -353,6 +354,18 @@ void free_results(sd_image_t* result_images, int num_results) {
free(result_images);
}

// Block until the async generation task completes, polling roughly once per
// second so a dropped client connection can cancel the in-flight generation.
//
// ft     - future tied to the generation task; an invalid future is a no-op
// sd_ctx - context whose generation is cancelled when the client disconnects
// req    - request whose connection state is polled
static void wait_for_generation(std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) {
    if (!ft.valid()) {
        return;
    }
    bool cancel_sent = false;  // issue the cancellation request only once
    for (;;) {
        const std::future_status status = ft.wait_for(std::chrono::milliseconds(1000));
        if (status == std::future_status::ready) {
            break;
        }
        if (status == std::future_status::deferred) {
            // A deferred task never becomes ready under wait_for(); run it
            // to completion instead of spinning forever.
            ft.wait();
            break;
        }
        if (!cancel_sent && req.is_connection_closed()) {
            sd_cancel_generation(sd_ctx, SD_CANCEL_ALL);
            cancel_sent = true;
        }
    }
}

void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
const std::string serve_html_path = svr_params.serve_html_path;
svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
Expand Down Expand Up @@ -498,11 +511,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
sd_image_t* results = nullptr;
int num_results = 0;

{
std::future<void> ft = std::async(std::launch::async, [&]() {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
});
wait_for_generation(ft, runtime->sd_ctx, req);

for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr) {
Expand Down Expand Up @@ -748,11 +762,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
sd_image_t* results = nullptr;
int num_results = 0;

{
std::future<void> ft = std::async(std::launch::async, [&]() {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
});
wait_for_generation(ft, runtime->sd_ctx, req);

json out;
out["created"] = static_cast<long long>(std::time(nullptr));
Expand Down Expand Up @@ -1065,11 +1080,12 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
sd_image_t* results = nullptr;
int num_results = 0;

{
std::future<void> ft = std::async(std::launch::async, [&]() {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
});
wait_for_generation(ft, runtime->sd_ctx, req);

json out;
out["images"] = json::array();
Expand Down
9 changes: 9 additions & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,15 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

// Cancellation modes accepted by sd_cancel_generation().
enum sd_cancel_mode_t
{
// Abort sampling AND skip decoding of any remaining latents.
SD_CANCEL_ALL,
// Stop sampling/new batch items; the latent-decode loop only aborts on
// SD_CANCEL_ALL, so already-produced latents are presumably still decoded
// — TODO confirm intended semantics.
SD_CANCEL_NEW_LATENTS,
// Sentinel meaning "no cancellation requested"; restored at the start of
// each generate_image()/generate_video() call.
SD_CANCEL_RESET
};

// Request cancellation of the generation currently running on sd_ctx.
// Null contexts are ignored; out-of-range modes fall back to SD_CANCEL_ALL.
SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);

SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);

Expand Down
51 changes: 51 additions & 0 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "latent-preview.h"
#include "name_conversion.h"

#include <atomic>

const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
Expand Down Expand Up @@ -102,6 +104,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) {

/*=============================================== StableDiffusionGGML ================================================*/

static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
"sd_cancel_mode_t must be lock-free");

class StableDiffusionGGML {
public:
ggml_backend_t backend = nullptr; // general backend
Expand Down Expand Up @@ -167,6 +172,20 @@ class StableDiffusionGGML {
ggml_backend_free(backend);
}

// Cancellation state shared between the generation thread and callers of
// sd_cancel_generation(). SD_CANCEL_RESET means "not cancelled".
// NOTE: must be explicitly initialized — a default-constructed std::atomic
// is uninitialized before C++20, and value-initialization would yield
// SD_CANCEL_ALL (enumerator 0), i.e. the already-cancelled state.
std::atomic<sd_cancel_mode_t> cancellation_flag{SD_CANCEL_RESET};

// Publish a cancellation request (or a reset) to the generation thread.
void set_cancel_flag(enum sd_cancel_mode_t flag) {
    cancellation_flag.store(flag, std::memory_order_release);
}

// Return to the "not cancelled" state before starting a new generation.
void reset_cancel_flag() {
    set_cancel_flag(SD_CANCEL_RESET);
}

// Read the current cancellation request; acquire pairs with the release store.
enum sd_cancel_mode_t get_cancel_flag() {
    return cancellation_flag.load(std::memory_order_acquire);
}

void init_backend() {
#ifdef SD_USE_CUDA
LOG_DEBUG("Using CUDA backend");
Expand Down Expand Up @@ -1625,6 +1644,13 @@ class StableDiffusionGGML {
SamplePreviewContext preview = prepare_sample_preview_context();

auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::Tensor<float> {

enum sd_cancel_mode_t cancel_flag = get_cancel_flag();
if (cancel_flag != SD_CANCEL_RESET) {
LOG_DEBUG("cancelling generation");
return {};
}

if (step == 1 || step == -1) {
pretty_progress(0, (int)steps, 0);
}
Expand Down Expand Up @@ -2392,6 +2418,16 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}

// Public C API entry point: request cancellation of the generation running on
// sd_ctx. A null context (or uninitialized inner state) is silently ignored;
// modes outside the declared enumerator range are clamped to SD_CANCEL_ALL.
void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode)
{
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
        return;
    }
    const bool in_range = (mode >= SD_CANCEL_ALL && mode <= SD_CANCEL_RESET);
    sd_ctx->sd->set_cancel_flag(in_range ? mode : SD_CANCEL_ALL);
}

enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
if (sd_version_is_dit(sd_ctx->sd->version)) {
Expand Down Expand Up @@ -3031,6 +3067,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
int64_t t0 = ggml_time_ms();

for (size_t i = 0; i < final_latents.size(); i++) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling latent decodings");
break;
}
int64_t t1 = ggml_time_ms();
sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]);
if (image.empty()) {
Expand Down Expand Up @@ -3069,6 +3109,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
return nullptr;
}

sd_ctx->sd->reset_cancel_flag();

int64_t t0 = ggml_time_ms();
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
GenerationRequest request(sd_ctx, sd_img_gen_params);
Expand Down Expand Up @@ -3104,6 +3146,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::vector<sd::Tensor<float>> final_latents;
int64_t denoise_start = ggml_time_ms();
for (int b = 0; b < request.batch_count; b++) {

if (sd_ctx->sd->get_cancel_flag() != SD_CANCEL_RESET) {
LOG_ERROR("cancelling generation");
break;
}

int64_t sampling_start = ggml_time_ms();
int64_t cur_seed = request.seed + b;
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed);
Expand Down Expand Up @@ -3430,6 +3478,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
return nullptr;
}

sd_ctx->sd->reset_cancel_flag();

if (num_frames_out != nullptr) {
*num_frames_out = 0;
}
Expand Down
Loading