From d16fd8de86c999547828e906cabc6f6e885a68bc Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Sat, 7 Feb 2026 08:56:24 -0300 Subject: [PATCH 1/2] feat: support for canceling the ongoing generation --- include/stable-diffusion.h | 9 +++++++ src/stable-diffusion.cpp | 51 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index f093bb56c..be64ac351 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -381,6 +381,15 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); +enum sd_cancel_mode_t +{ + SD_CANCEL_ALL, + SD_CANCEL_NEW_LATENTS, + SD_CANCEL_RESET +}; + +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode); + SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 093bed20e..cc4c1f286 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -22,6 +22,8 @@ #include "latent-preview.h" #include "name_conversion.h" +#include <atomic> + const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", @@ -102,6 +104,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) { /*=============================================== StableDiffusionGGML ================================================*/ +static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free, + "sd_cancel_mode_t must be lock-free"); + class StableDiffusionGGML { public: ggml_backend_t backend = nullptr; // general backend @@ -167,6 +172,20 @@ class StableDiffusionGGML { ggml_backend_free(backend); } + std::atomic<sd_cancel_mode_t> cancellation_flag; + + void 
set_cancel_flag(enum sd_cancel_mode_t flag) { + cancellation_flag.store(flag, std::memory_order_release); + } + + void reset_cancel_flag() { + set_cancel_flag(SD_CANCEL_RESET); + } + + enum sd_cancel_mode_t get_cancel_flag() { + return cancellation_flag.load(std::memory_order_acquire); + } + void init_backend() { #ifdef SD_USE_CUDA LOG_DEBUG("Using CUDA backend"); @@ -1625,6 +1644,13 @@ class StableDiffusionGGML { SamplePreviewContext preview = prepare_sample_preview_context(); auto denoise = [&](const sd::Tensor& x, float sigma, int step) -> sd::Tensor { + + enum sd_cancel_mode_t cancel_flag = get_cancel_flag(); + if (cancel_flag != SD_CANCEL_RESET) { + LOG_DEBUG("cancelling generation"); + return {}; + } + if (step == 1 || step == -1) { pretty_progress(0, (int)steps, 0); } @@ -2392,6 +2418,16 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) +{ + if (sd_ctx && sd_ctx->sd) { + if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) { + mode = SD_CANCEL_ALL; + } + sd_ctx->sd->set_cancel_flag(mode); + } +} + enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { if (sd_version_is_dit(sd_ctx->sd->version)) { @@ -3031,6 +3067,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, int64_t t0 = ggml_time_ms(); for (size_t i = 0; i < final_latents.size(); i++) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling latent decodings"); + break; + } int64_t t1 = ggml_time_ms(); sd::Tensor image = sd_ctx->sd->decode_first_stage(final_latents[i]); if (image.empty()) { @@ -3069,6 +3109,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s return nullptr; } + sd_ctx->sd->reset_cancel_flag(); + int64_t t0 = ggml_time_ms(); sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; GenerationRequest request(sd_ctx, sd_img_gen_params); @@ -3104,6 +3146,12 @@ 
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s std::vector> final_latents; int64_t denoise_start = ggml_time_ms(); for (int b = 0; b < request.batch_count; b++) { + + if (sd_ctx->sd->get_cancel_flag() != SD_CANCEL_RESET) { + LOG_ERROR("cancelling generation"); + break; + } + int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = request.seed + b; LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed); @@ -3430,6 +3478,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) { return nullptr; } + + sd_ctx->sd->reset_cancel_flag(); + if (num_frames_out != nullptr) { *num_frames_out = 0; } From b1688530f03e63a1cf8c10a9558169f0929d78ad Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Thu, 12 Feb 2026 17:44:26 -0300 Subject: [PATCH 2/2] feat(server): cancel current generation on client disconnect Co-authored-by: donington --- examples/server/main.cpp | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 6a5036975..72da4b0db 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -2,6 +2,7 @@ #include #include #include +#include <future> #include #include #include @@ -353,6 +354,18 @@ void free_results(sd_image_t* result_images, int num_results) { free(result_images); } +static void wait_for_generation(std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) { + std::future_status ft_status; + do { + if (!ft.valid()) + break; + ft_status = ft.wait_for(std::chrono::milliseconds(1000)); + if (req.is_connection_closed()) { + sd_cancel_generation(sd_ctx, SD_CANCEL_ALL); + } + } while (ft_status != std::future_status::ready); +} + void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) { const std::string serve_html_path = svr_params.serve_html_path; 
svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) { @@ -498,11 +511,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { sd_image_t* results = nullptr; int num_results = 0; - { + std::future ft = std::async(std::launch::async, [&]() { std::lock_guard lock(*runtime->sd_ctx_mutex); results = generate_image(runtime->sd_ctx, &img_gen_params); num_results = gen_params.batch_count; - } + }); + wait_for_generation(ft, runtime->sd_ctx, req); for (int i = 0; i < num_results; i++) { if (results[i].data == nullptr) { @@ -748,11 +762,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { sd_image_t* results = nullptr; int num_results = 0; - { + std::future ft = std::async(std::launch::async, [&]() { std::lock_guard lock(*runtime->sd_ctx_mutex); results = generate_image(runtime->sd_ctx, &img_gen_params); num_results = gen_params.batch_count; - } + }); + wait_for_generation(ft, runtime->sd_ctx, req); json out; out["created"] = static_cast(std::time(nullptr)); @@ -1065,11 +1080,12 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) { sd_image_t* results = nullptr; int num_results = 0; - { + std::future ft = std::async(std::launch::async, [&]() { std::lock_guard lock(*runtime->sd_ctx_mutex); results = generate_image(runtime->sd_ctx, &img_gen_params); num_results = gen_params.batch_count; - } + }); + wait_for_generation(ft, runtime->sd_ctx, req); json out; out["images"] = json::array();