Skip to content

Commit 6404d53

Browse files
committed
feat: Euler Ancestral sampler implementation for flow models
1 parent 738bc8e commit 6404d53

2 files changed

Lines changed: 63 additions & 3 deletions

File tree

src/denoiser.hpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,52 @@ static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
809809
return x;
810810
}
811811

812+
static sd::Tensor<float> sample_euler_flow(denoise_cb_t model,
813+
sd::Tensor<float> x,
814+
const std::vector<float>& sigmas,
815+
std::shared_ptr<RNG> rng,
816+
float eta) {
817+
int steps = static_cast<int>(sigmas.size()) - 1;
818+
for (int i = 0; i < steps; i++) {
819+
float sigma = sigmas[i];
820+
float sigma_to = sigmas[i + 1];
821+
auto denoised_opt = model(x, sigma, i + 1);
822+
if (denoised_opt.empty()) {
823+
return {};
824+
}
825+
sd::Tensor<float> denoised = std::move(denoised_opt);
826+
if (sigma_to == 0) {
827+
// x = x × (sigma_to / sigma) + denoised × (1 - (sigma_to / sigma)) // below
828+
// = x × ( 0 / sigma) + denoised × (1 - (0 / sigma))
829+
// = denoised
830+
x = denoised;
831+
} else if (eta == 0) {
832+
// x = x + d × (sigma_to - sigma)
833+
// = x + ((x - denoised) / sigma) × (sigma_to - sigma)
834+
// = x + (x - denoised) × (sigma_to / sigma - 1)
835+
// = x + x × (sigma_to / sigma) - x - denoised × (sigma_to / sigma) + denoised
836+
// = x × (sigma_to / sigma) + denoised × (1 - (sigma_to / sigma))
837+
float sigma_ratio = sigma_to / sigma;
838+
x = sigma_ratio * x + (1.0 - sigma_ratio) * denoised;
839+
} else {
840+
float downstep_ratio = 1.0f + (sigma_to / sigma - 1.0f) * eta;
841+
float sigma_down = sigma_to * downstep_ratio;
842+
float sigma_ratio = sigma_down / sigma;
843+
x = sigma_ratio * x + (1.0 - sigma_ratio) * denoised;
844+
845+
float alpha_scale = (1 - sigma_to) / (1 - sigma_down);
846+
847+
// sigma_up = √(sigma_to² - sigma_down² × alpha_scale²)
848+
// = √(sigma_to² - sigma_to² × downstep_ratio² × alpha_scale²)
849+
// = sigma_to × √(1 - downstep_ratio² × alpha_scale²)
850+
float term = downstep_ratio * alpha_scale;
851+
float sigma_up = sigma_to * std::sqrt((1.0f + term) * (1.0f - term));
852+
x = alpha_scale * x + sd::Tensor<float>::randn_like(x, rng) * sigma_up;
853+
}
854+
}
855+
return x;
856+
}
857+
812858
static sd::Tensor<float> sample_euler(denoise_cb_t model,
813859
sd::Tensor<float> x,
814860
const std::vector<float>& sigmas) {
@@ -1370,10 +1416,14 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
13701416
sd::Tensor<float> x,
13711417
std::vector<float> sigmas,
13721418
std::shared_ptr<RNG> rng,
1373-
float eta) {
1419+
float eta,
1420+
bool is_flow_denoiser) {
13741421
switch (method) {
13751422
case EULER_A_SAMPLE_METHOD:
1376-
return sample_euler_ancestral(model, std::move(x), sigmas, rng, eta);
1423+
if (is_flow_denoiser)
1424+
return sample_euler_flow(model, std::move(x), sigmas, rng, eta);
1425+
else
1426+
return sample_euler_ancestral(model, std::move(x), sigmas, rng, eta);
13771427
case EULER_SAMPLE_METHOD:
13781428
return sample_euler(model, std::move(x), sigmas);
13791429
case HEUN_SAMPLE_METHOD:

src/stable-diffusion.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1593,6 +1593,7 @@ class StableDiffusionGGML {
15931593
float eta,
15941594
int shifted_timestep,
15951595
sample_method_t method,
1596+
bool is_flow_denoiser,
15961597
const std::vector<float>& sigmas,
15971598
int start_merge_step,
15981599
const std::vector<sd::Tensor<float>>& ref_latents,
@@ -1791,7 +1792,7 @@ class StableDiffusionGGML {
17911792
return denoised;
17921793
};
17931794

1794-
auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta);
1795+
auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta, is_flow_denoiser);
17951796
if (x0_opt.empty()) {
17961797
LOG_ERROR("Diffusion model sampling failed");
17971798
if (control_net) {
@@ -1909,6 +1910,12 @@ class StableDiffusionGGML {
19091910
flow_denoiser->set_shift(flow_shift);
19101911
}
19111912
}
1913+
1914+
bool is_flow_denoiser() {
1915+
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
1916+
return !!flow_denoiser;
1917+
}
1918+
19121919
};
19131920

19141921
/*================================================= SD API ==================================================*/
@@ -3150,6 +3157,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
31503157
plan.eta,
31513158
request.shifted_timestep,
31523159
plan.sample_method,
3160+
sd_ctx->sd->is_flow_denoiser(),
31533161
plan.sigmas,
31543162
plan.start_merge_step,
31553163
latents.ref_latents,
@@ -3509,6 +3517,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35093517
sd_vid_gen_params->high_noise_sample_params.eta,
35103518
request.shifted_timestep,
35113519
plan.high_noise_sample_method,
3520+
sd_ctx->sd->is_flow_denoiser(),
35123521
high_noise_sigmas,
35133522
-1,
35143523
std::vector<sd::Tensor<float>>{},
@@ -3550,6 +3559,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35503559
sd_vid_gen_params->sample_params.eta,
35513560
sd_vid_gen_params->sample_params.shifted_timestep,
35523561
plan.sample_method,
3562+
sd_ctx->sd->is_flow_denoiser(),
35533563
plan.sigmas,
35543564
-1,
35553565
std::vector<sd::Tensor<float>>{},

0 commit comments

Comments
 (0)