Skip to content

Commit 9d47c33

Browse files
committed
Fix LoRA loading when using multiple CLIP backends
1 parent 046ffc3 commit 9d47c33

2 files changed

Lines changed: 196 additions & 43 deletions

File tree

src/conditioner.hpp

Lines changed: 162 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ struct Conditioner {
8282
virtual ~Conditioner() = default;
8383

8484
public:
85+
int model_count = 1;
8586
virtual SDCondition get_learned_condition(int n_threads,
8687
const ConditionerParams& conditioner_params) = 0;
8788
virtual void alloc_params_buffer() = 0;
@@ -97,6 +98,11 @@ struct Conditioner {
9798
virtual std::string remove_trigger_from_prompt(const std::string& prompt) {
9899
GGML_ABORT("Not implemented yet!");
99100
}
101+
virtual bool is_cond_stage_model_name_at_index(const std::string& name, int index) {
102+
return true;
103+
}
104+
virtual ggml_backend_t get_params_backend_at_index(int index) = 0;
105+
virtual ggml_backend_t get_runtime_backend_at_index(int index) = 0;
100106
};
101107

102108
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -139,8 +145,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
139145
LOG_INFO("CLIP-H: using %s backend", ggml_backend_name(clip_backend));
140146
text_model = std::make_shared<CLIPTextModelRunner>(clip_backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
141147
} else if (sd_version_is_sdxl(version)) {
148+
model_count = 2;
142149
ggml_backend_t clip_g_backend = clip_backend;
143-
if (backends.size() >= 2){
150+
if (backends.size() >= 2) {
144151
clip_g_backend = backends[1];
145152
if (backends.size() > 2) {
146153
LOG_WARN("More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest.");
@@ -669,6 +676,42 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
669676
conditioner_params.adm_in_channels,
670677
conditioner_params.zero_out_masked);
671678
}
679+
680+
bool is_cond_stage_model_name_at_index(const std::string& name, int index) override {
681+
if (sd_version_is_sdxl(version)) {
682+
if (index == 0) {
683+
return contains(name, "cond_stage_model.model.transformer");
684+
} else if (index == 1) {
685+
return contains(name, "cond_stage_model.model.1");
686+
} else {
687+
return false;
688+
}
689+
}
690+
return true;
691+
}
692+
693+
ggml_backend_t get_params_backend_at_index(int index){
694+
if (sd_version_is_sdxl(version) && index == 1){
695+
if(text_model2) {
696+
return text_model2->get_params_backend();
697+
}
698+
} else if (text_model) {
699+
return text_model->get_params_backend();
700+
}
701+
return nullptr;
702+
}
703+
704+
ggml_backend_t get_runtime_backend_at_index(int index){
705+
if (sd_version_is_sdxl(version) && index == 1){
706+
if(text_model2) {
707+
return text_model2->get_runtime_backend();
708+
}
709+
} else if (text_model) {
710+
return text_model->get_runtime_backend();
711+
}
712+
return nullptr;
713+
}
714+
672715
};
673716

674717
struct FrozenCLIPVisionEmbedder : public GGMLRunner {
@@ -741,12 +784,14 @@ struct SD3CLIPEmbedder : public Conditioner {
741784
bool use_clip_g = false;
742785
bool use_t5 = false;
743786

787+
model_count = 3;
788+
744789
ggml_backend_t clip_l_backend, clip_g_backend, t5_backend;
745790
if (backends.size() == 1) {
746791
clip_l_backend = clip_g_backend = t5_backend = backends[0];
747792
} else if (backends.size() == 2) {
748793
clip_l_backend = clip_g_backend = backends[0];
749-
t5_backend = backends[1];
794+
t5_backend = backends[1];
750795
} else if (backends.size() >= 3) {
751796
clip_l_backend = backends[0];
752797
clip_g_backend = backends[1];
@@ -1098,6 +1143,42 @@ struct SD3CLIPEmbedder : public Conditioner {
10981143
conditioner_params.clip_skip,
10991144
conditioner_params.zero_out_masked);
11001145
}
1146+
1147+
bool is_cond_stage_model_name_at_index(const std::string& name, int index) override {
1148+
if (index == 0) {
1149+
return contains(name, "text_encoders.clip_l");
1150+
} else if (index == 1) {
1151+
return contains(name, "text_encoders.clip_g");
1152+
} else if (index == 2) {
1153+
return contains(name, "text_encoders.t5xxl");
1154+
} else {
1155+
return false;
1156+
}
1157+
}
1158+
1159+
ggml_backend_t get_params_backend_at_index(int index){
1160+
if (index == 0 && clip_l) {
1161+
return clip_l->get_params_backend();
1162+
} else if (index == 1 && clip_g) {
1163+
return clip_g->get_params_backend();
1164+
} else if (index == 2 && t5) {
1165+
return t5->get_params_backend();
1166+
} else {
1167+
return nullptr;
1168+
}
1169+
}
1170+
1171+
ggml_backend_t get_runtime_backend_at_index(int index){
1172+
if (index == 0 && clip_l) {
1173+
return clip_l->get_runtime_backend();
1174+
} else if (index == 1 && clip_g) {
1175+
return clip_g->get_runtime_backend();
1176+
} else if (index == 2 && t5) {
1177+
return t5->get_runtime_backend();
1178+
} else {
1179+
return nullptr;
1180+
}
1181+
}
11011182
};
11021183

11031184
struct FluxCLIPEmbedder : public Conditioner {
@@ -1113,19 +1194,19 @@ struct FluxCLIPEmbedder : public Conditioner {
11131194
bool use_clip_l = false;
11141195
bool use_t5 = false;
11151196

1197+
model_count = 2;
11161198

11171199
ggml_backend_t clip_l_backend, t5_backend;
11181200
if (backends.size() == 1) {
11191201
clip_l_backend = t5_backend = backends[0];
11201202
} else if (backends.size() >= 2) {
11211203
clip_l_backend = backends[0];
1122-
t5_backend = backends[1];
1204+
t5_backend = backends[1];
11231205
if (backends.size() > 2) {
11241206
LOG_WARN("More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest.");
11251207
}
11261208
}
11271209

1128-
11291210
for (auto pair : tensor_storage_map) {
11301211
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
11311212
use_clip_l = true;
@@ -1358,6 +1439,36 @@ struct FluxCLIPEmbedder : public Conditioner {
13581439
conditioner_params.clip_skip,
13591440
conditioner_params.zero_out_masked);
13601441
}
1442+
1443+
bool is_cond_stage_model_name_at_index(const std::string& name, int index) override {
1444+
if (index == 0) {
1445+
return contains(name, "text_encoders.clip_l");
1446+
} else if (index == 1) {
1447+
return contains(name, "text_encoders.t5xxl");
1448+
} else {
1449+
return false;
1450+
}
1451+
}
1452+
1453+
ggml_backend_t get_params_backend_at_index(int index){
1454+
if (index == 0 && clip_l) {
1455+
return clip_l->get_params_backend();
1456+
} else if (index == 1 && t5) {
1457+
return t5->get_params_backend();
1458+
} else {
1459+
return nullptr;
1460+
}
1461+
}
1462+
1463+
ggml_backend_t get_runtime_backend_at_index(int index){
1464+
if (index == 0 && clip_l) {
1465+
return clip_l->get_runtime_backend();
1466+
} else if (index == 1 && t5) {
1467+
return t5->get_runtime_backend();
1468+
} else {
1469+
return nullptr;
1470+
}
1471+
}
13611472
};
13621473

13631474
struct T5CLIPEmbedder : public Conditioner {
@@ -1554,6 +1665,20 @@ struct T5CLIPEmbedder : public Conditioner {
15541665
conditioner_params.clip_skip,
15551666
conditioner_params.zero_out_masked);
15561667
}
1668+
1669+
ggml_backend_t get_params_backend_at_index(int index){
1670+
if (t5){
1671+
return t5->get_params_backend();
1672+
}
1673+
return nullptr;
1674+
}
1675+
1676+
ggml_backend_t get_runtime_backend_at_index(int index){
1677+
if (t5){
1678+
return t5->get_runtime_backend();
1679+
}
1680+
return nullptr;
1681+
}
15571682
};
15581683

15591684
struct AnimaConditioner : public Conditioner {
@@ -1566,11 +1691,11 @@ struct AnimaConditioner : public Conditioner {
15661691
const String2TensorStorage& tensor_storage_map = {}) {
15671692
qwen_tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
15681693
llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
1569-
backend,
1570-
offload_params_to_cpu,
1571-
tensor_storage_map,
1572-
"text_encoders.llm",
1573-
false);
1694+
backend,
1695+
offload_params_to_cpu,
1696+
tensor_storage_map,
1697+
"text_encoders.llm",
1698+
false);
15741699
}
15751700

15761701
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
@@ -1668,6 +1793,20 @@ struct AnimaConditioner : public Conditioner {
16681793
result.c_t5_weights = std::move(t5_weight_tensor);
16691794
return result;
16701795
}
1796+
1797+
ggml_backend_t get_params_backend_at_index(int index){
1798+
if (llm){
1799+
return llm->get_params_backend();
1800+
}
1801+
return nullptr;
1802+
}
1803+
1804+
ggml_backend_t get_runtime_backend_at_index(int index){
1805+
if (llm){
1806+
return llm->get_runtime_backend();
1807+
}
1808+
return nullptr;
1809+
}
16711810
};
16721811

16731812
struct LLMEmbedder : public Conditioner {
@@ -2012,6 +2151,20 @@ struct LLMEmbedder : public Conditioner {
20122151
result.extra_c_crossattns = std::move(extra_hidden_states_vec);
20132152
return result;
20142153
}
2154+
2155+
ggml_backend_t get_params_backend_at_index(int index){
2156+
if (llm){
2157+
return llm->get_params_backend();
2158+
}
2159+
return nullptr;
2160+
}
2161+
2162+
ggml_backend_t get_runtime_backend_at_index(int index){
2163+
if (llm){
2164+
return llm->get_runtime_backend();
2165+
}
2166+
return nullptr;
2167+
}
20152168
};
20162169

20172170
#endif

src/stable-diffusion.cpp

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,14 +1249,6 @@ class StableDiffusionGGML {
12491249
for (auto& kv : lora_state_diff) {
12501250
bool applied = false;
12511251
int64_t t0 = ggml_time_ms();
1252-
// TODO: Fix that
1253-
bool are_clip_backends_similar = true;
1254-
for (auto backend: clip_backends){
1255-
are_clip_backends_similar = are_clip_backends_similar && (clip_backends[0]==backend || ggml_backend_is_cpu(backend));
1256-
}
1257-
if(!are_clip_backends_similar){
1258-
LOG_WARN("Text encoders are running on different backends. This may cause issues when immediately applying LoRAs.");
1259-
}
12601252
auto lora_tensor_filter_diff = [&](const std::string& tensor_name) {
12611253
if (is_diffusion_model_name(tensor_name)) {
12621254
return true;
@@ -1272,19 +1264,22 @@ class StableDiffusionGGML {
12721264
applied = true;
12731265
}
12741266

1275-
auto lora_tensor_filter_cond = [&](const std::string& tensor_name) {
1276-
if (is_cond_stage_model_name(tensor_name)) {
1277-
return true;
1267+
for (int i = 0; i < cond_stage_model->model_count; i++) {
1268+
auto lora_tensor_filter_cond = [&](const std::string& tensor_name) {
1269+
if (is_cond_stage_model_name(tensor_name)) {
1270+
return cond_stage_model->is_cond_stage_model_name_at_index(tensor_name, i);
1271+
}
1272+
return false;
1273+
};
1274+
// TODO: split by model
1275+
LOG_INFO("applying lora to text encoder (%d)", i);
1276+
auto backend = cond_stage_model->get_params_backend_at_index(i);
1277+
lora = load_lora_model_from_file(kv.first, kv.second, backend, lora_tensor_filter_cond);
1278+
if (lora && !lora->lora_tensors.empty()) {
1279+
lora->apply(tensors, version, n_threads);
1280+
lora->free_params_buffer();
1281+
applied = true;
12781282
}
1279-
return false;
1280-
};
1281-
// TODO: split by model
1282-
LOG_INFO("applying lora to text encoders");
1283-
lora = load_lora_model_from_file(kv.first, kv.second, clip_backends[0], lora_tensor_filter_cond);
1284-
if (lora && !lora->lora_tensors.empty()) {
1285-
lora->apply(tensors, version, n_threads);
1286-
lora->free_params_buffer();
1287-
applied = true;
12881283
}
12891284

12901285
auto lora_tensor_filter_first = [&](const std::string& tensor_name) {
@@ -1346,22 +1341,27 @@ class StableDiffusionGGML {
13461341
}
13471342
}
13481343
cond_stage_lora_models = lora_models;
1349-
auto lora_tensor_filter = [&](const std::string& tensor_name) {
1350-
if (is_cond_stage_model_name(tensor_name)) {
1351-
return true;
1352-
}
1353-
return false;
1354-
};
1355-
for (auto& kv : lora_state_diff) {
1356-
const std::string& lora_id = kv.first;
1357-
float multiplier = kv.second;
1358-
//TODO: split by model
1359-
auto lora = load_lora_model_from_file(lora_id, multiplier, clip_backends[0], lora_tensor_filter);
1360-
if (lora && !lora->lora_tensors.empty()) {
1361-
lora->preprocess_lora_tensors(tensors);
1362-
cond_stage_lora_models.push_back(lora);
1344+
1345+
1346+
for(int i=0;i<cond_stage_model->model_count;i++){
1347+
auto lora_tensor_filter_cond = [&](const std::string& tensor_name) {
1348+
if (is_cond_stage_model_name(tensor_name)) {
1349+
return cond_stage_model->is_cond_stage_model_name_at_index(tensor_name, i);
1350+
}
1351+
return false;
1352+
};
1353+
for (auto& kv : lora_state_diff) {
1354+
const std::string& lora_id = kv.first;
1355+
float multiplier = kv.second;
1356+
auto backend = cond_stage_model->get_runtime_backend_at_index(i);
1357+
auto lora = load_lora_model_from_file(kv.first, kv.second, backend, lora_tensor_filter_cond);
1358+
if (lora && !lora->lora_tensors.empty()) {
1359+
lora->preprocess_lora_tensors(tensors);
1360+
cond_stage_lora_models.push_back(lora);
1361+
}
13631362
}
13641363
}
1364+
13651365
auto multi_lora_adapter = std::make_shared<MultiLoraAdapter>(cond_stage_lora_models);
13661366
cond_stage_model->set_weight_adapter(multi_lora_adapter);
13671367
}

0 commit comments

Comments
 (0)