@@ -82,6 +82,7 @@ struct Conditioner {
8282 virtual ~Conditioner () = default ;
8383
8484public:
85+ int model_count = 1 ;
8586 virtual SDCondition get_learned_condition (int n_threads,
8687 const ConditionerParams& conditioner_params) = 0;
8788 virtual void alloc_params_buffer () = 0;
@@ -97,6 +98,11 @@ struct Conditioner {
9798 virtual std::string remove_trigger_from_prompt (const std::string& prompt) {
9899 GGML_ABORT (" Not implemented yet!" );
99100 }
101+ virtual bool is_cond_stage_model_name_at_index (const std::string& name, int index) {
102+ return true ;
103+ }
104+ virtual ggml_backend_t get_params_backend_at_index (int index) = 0;
105+ virtual ggml_backend_t get_runtime_backend_at_index (int index) = 0;
100106};
101107
102108// ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -139,8 +145,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
139145 LOG_INFO (" CLIP-H: using %s backend" , ggml_backend_name (clip_backend));
140146 text_model = std::make_shared<CLIPTextModelRunner>(clip_backend, offload_params_to_cpu, tensor_storage_map, " cond_stage_model.transformer.text_model" , OPEN_CLIP_VIT_H_14, true , force_clip_f32);
141147 } else if (sd_version_is_sdxl (version)) {
148+ model_count = 2 ;
142149 ggml_backend_t clip_g_backend = clip_backend;
143- if (backends.size () >= 2 ){
150+ if (backends.size () >= 2 ) {
144151 clip_g_backend = backends[1 ];
145152 if (backends.size () > 2 ) {
146153 LOG_WARN (" More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest." );
@@ -669,6 +676,42 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
669676 conditioner_params.adm_in_channels ,
670677 conditioner_params.zero_out_masked );
671678 }
679+
680+ bool is_cond_stage_model_name_at_index (const std::string& name, int index) override {
681+ if (sd_version_is_sdxl (version)) {
682+ if (index == 0 ) {
683+ return contains (name, " cond_stage_model.model.transformer" );
684+ } else if (index == 1 ) {
685+ return contains (name, " cond_stage_model.model.1" );
686+ } else {
687+ return false ;
688+ }
689+ }
690+ return true ;
691+ }
692+
693+ ggml_backend_t get_params_backend_at_index (int index){
694+ if (sd_version_is_sdxl (version) && index == 1 ){
695+ if (text_model2) {
696+ return text_model2->get_params_backend ();
697+ }
698+ } else if (text_model) {
699+ return text_model->get_params_backend ();
700+ }
701+ return nullptr ;
702+ }
703+
704+ ggml_backend_t get_runtime_backend_at_index (int index){
705+ if (sd_version_is_sdxl (version) && index == 1 ){
706+ if (text_model2) {
707+ return text_model2->get_runtime_backend ();
708+ }
709+ } else if (text_model) {
710+ return text_model->get_runtime_backend ();
711+ }
712+ return nullptr ;
713+ }
714+
672715};
673716
674717struct FrozenCLIPVisionEmbedder : public GGMLRunner {
@@ -741,12 +784,14 @@ struct SD3CLIPEmbedder : public Conditioner {
741784 bool use_clip_g = false ;
742785 bool use_t5 = false ;
743786
787+ model_count = 3 ;
788+
744789 ggml_backend_t clip_l_backend, clip_g_backend, t5_backend;
745790 if (backends.size () == 1 ) {
746791 clip_l_backend = clip_g_backend = t5_backend = backends[0 ];
747792 } else if (backends.size () == 2 ) {
748793 clip_l_backend = clip_g_backend = backends[0 ];
749- t5_backend = backends[1 ];
794+ t5_backend = backends[1 ];
750795 } else if (backends.size () >= 3 ) {
751796 clip_l_backend = backends[0 ];
752797 clip_g_backend = backends[1 ];
@@ -1098,6 +1143,42 @@ struct SD3CLIPEmbedder : public Conditioner {
10981143 conditioner_params.clip_skip ,
10991144 conditioner_params.zero_out_masked );
11001145 }
1146+
1147+ bool is_cond_stage_model_name_at_index (const std::string& name, int index) override {
1148+ if (index == 0 ) {
1149+ return contains (name, " text_encoders.clip_l" );
1150+ } else if (index == 1 ) {
1151+ return contains (name, " text_encoders.clip_g" );
1152+ } else if (index == 2 ) {
1153+ return contains (name, " text_encoders.t5xxl" );
1154+ } else {
1155+ return false ;
1156+ }
1157+ }
1158+
1159+ ggml_backend_t get_params_backend_at_index (int index){
1160+ if (index == 0 && clip_l) {
1161+ return clip_l->get_params_backend ();
1162+ } else if (index == 1 && clip_g) {
1163+ return clip_g->get_params_backend ();
1164+ } else if (index == 2 && t5) {
1165+ return t5->get_params_backend ();
1166+ } else {
1167+ return nullptr ;
1168+ }
1169+ }
1170+
1171+ ggml_backend_t get_runtime_backend_at_index (int index){
1172+ if (index == 0 && clip_l) {
1173+ return clip_l->get_runtime_backend ();
1174+ } else if (index == 1 && clip_g) {
1175+ return clip_g->get_runtime_backend ();
1176+ } else if (index == 2 && t5) {
1177+ return t5->get_runtime_backend ();
1178+ } else {
1179+ return nullptr ;
1180+ }
1181+ }
11011182};
11021183
11031184struct FluxCLIPEmbedder : public Conditioner {
@@ -1113,19 +1194,19 @@ struct FluxCLIPEmbedder : public Conditioner {
11131194 bool use_clip_l = false ;
11141195 bool use_t5 = false ;
11151196
1197+ model_count = 2 ;
11161198
11171199 ggml_backend_t clip_l_backend, t5_backend;
11181200 if (backends.size () == 1 ) {
11191201 clip_l_backend = t5_backend = backends[0 ];
11201202 } else if (backends.size () >= 2 ) {
11211203 clip_l_backend = backends[0 ];
1122- t5_backend = backends[1 ];
1204+ t5_backend = backends[1 ];
11231205 if (backends.size () > 2 ) {
11241206 LOG_WARN (" More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest." );
11251207 }
11261208 }
11271209
1128-
11291210 for (auto pair : tensor_storage_map) {
11301211 if (pair.first .find (" text_encoders.clip_l" ) != std::string::npos) {
11311212 use_clip_l = true ;
@@ -1358,6 +1439,36 @@ struct FluxCLIPEmbedder : public Conditioner {
13581439 conditioner_params.clip_skip ,
13591440 conditioner_params.zero_out_masked );
13601441 }
1442+
1443+ bool is_cond_stage_model_name_at_index (const std::string& name, int index) override {
1444+ if (index == 0 ) {
1445+ return contains (name, " text_encoders.clip_l" );
1446+ } else if (index == 1 ) {
1447+ return contains (name, " text_encoders.t5xxl" );
1448+ } else {
1449+ return false ;
1450+ }
1451+ }
1452+
1453+ ggml_backend_t get_params_backend_at_index (int index){
1454+ if (index == 0 && clip_l) {
1455+ return clip_l->get_params_backend ();
1456+ } else if (index == 1 && t5) {
1457+ return t5->get_params_backend ();
1458+ } else {
1459+ return nullptr ;
1460+ }
1461+ }
1462+
1463+ ggml_backend_t get_runtime_backend_at_index (int index){
1464+ if (index == 0 && clip_l) {
1465+ return clip_l->get_runtime_backend ();
1466+ } else if (index == 1 && t5) {
1467+ return t5->get_runtime_backend ();
1468+ } else {
1469+ return nullptr ;
1470+ }
1471+ }
13611472};
13621473
13631474struct T5CLIPEmbedder : public Conditioner {
@@ -1554,6 +1665,20 @@ struct T5CLIPEmbedder : public Conditioner {
15541665 conditioner_params.clip_skip ,
15551666 conditioner_params.zero_out_masked );
15561667 }
1668+
1669+ ggml_backend_t get_params_backend_at_index (int index){
1670+ if (t5){
1671+ return t5->get_params_backend ();
1672+ }
1673+ return nullptr ;
1674+ }
1675+
1676+ ggml_backend_t get_runtime_backend_at_index (int index){
1677+ if (t5){
1678+ return t5->get_runtime_backend ();
1679+ }
1680+ return nullptr ;
1681+ }
15571682};
15581683
15591684struct AnimaConditioner : public Conditioner {
@@ -1566,11 +1691,11 @@ struct AnimaConditioner : public Conditioner {
15661691 const String2TensorStorage& tensor_storage_map = {}) {
15671692 qwen_tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
15681693 llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
1569- backend,
1570- offload_params_to_cpu,
1571- tensor_storage_map,
1572- " text_encoders.llm" ,
1573- false );
1694+ backend,
1695+ offload_params_to_cpu,
1696+ tensor_storage_map,
1697+ " text_encoders.llm" ,
1698+ false );
15741699 }
15751700
15761701 void get_param_tensors (std::map<std::string, ggml_tensor*>& tensors) override {
@@ -1668,6 +1793,20 @@ struct AnimaConditioner : public Conditioner {
16681793 result.c_t5_weights = std::move (t5_weight_tensor);
16691794 return result;
16701795 }
1796+
1797+ ggml_backend_t get_params_backend_at_index (int index){
1798+ if (llm){
1799+ return llm->get_params_backend ();
1800+ }
1801+ return nullptr ;
1802+ }
1803+
1804+ ggml_backend_t get_runtime_backend_at_index (int index){
1805+ if (llm){
1806+ return llm->get_runtime_backend ();
1807+ }
1808+ return nullptr ;
1809+ }
16711810};
16721811
16731812struct LLMEmbedder : public Conditioner {
@@ -2012,6 +2151,20 @@ struct LLMEmbedder : public Conditioner {
20122151 result.extra_c_crossattns = std::move (extra_hidden_states_vec);
20132152 return result;
20142153 }
2154+
2155+ ggml_backend_t get_params_backend_at_index (int index){
2156+ if (llm){
2157+ return llm->get_params_backend ();
2158+ }
2159+ return nullptr ;
2160+ }
2161+
2162+ ggml_backend_t get_runtime_backend_at_index (int index){
2163+ if (llm){
2164+ return llm->get_runtime_backend ();
2165+ }
2166+ return nullptr ;
2167+ }
20152168};
20162169
20172170#endif
0 commit comments