diff --git a/src/model.cpp b/src/model.cpp index d23b97fac..cb19cdd04 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -368,6 +368,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string } else if (is_safetensors_file(file_path)) { LOG_INFO("load %s using safetensors format", file_path.c_str()); return init_from_safetensors_file(file_path, prefix); + } else if (ends_with(file_path, ".safetensors.index.json") && file_exists(file_path)) { + LOG_INFO("load %s using sharded safetensors format", file_path.c_str()); + return init_from_safetensors_index(file_path, prefix); + } else if (ends_with(file_path, ".safetensors") && file_exists(file_path + ".index.json")) { + LOG_INFO("load %s using sharded safetensors format (index found)", file_path.c_str()); + return init_from_safetensors_index(file_path + ".index.json", prefix); } else if (is_zip_file(file_path)) { LOG_INFO("load %s using checkpoint format", file_path.c_str()); return init_from_ckpt_file(file_path, prefix); @@ -386,10 +392,45 @@ void ModelLoader::convert_tensors_name() { String2TensorStorage new_map; for (auto& [_, tensor_storage] : tensor_storage_map) { - auto new_name = convert_tensor_name(tensor_storage.name, version); - // LOG_DEBUG("%s -> %s", tensor_storage.name.c_str(), new_name.c_str()); - tensor_storage.name = new_name; - new_map[new_name] = std::move(tensor_storage); + std::string old_name = tensor_storage.name; + auto new_name = convert_tensor_name(old_name, version); + // LOG_DEBUG("%s -> %s", old_name.c_str(), new_name.c_str()); + + // FLUX.2 diffusers fix: norm_out.linear.weight stores [shift, scale] while + // BFL's final_layer.adaLN_modulation.1.weight stores [scale, shift] (swapped halves). + // When loading from diffusers (name changed), split into two halves and swap their + // file offsets so the split-fuse mechanism loads them in the correct order. 
+ bool needs_half_swap = (old_name != new_name && + sd_version_is_flux2(version) && + ends_with(new_name, "final_layer.adaLN_modulation.1.weight") && + tensor_storage.n_dims == 2 && + tensor_storage.ne[1] % 2 == 0); + + if (needs_half_swap) { + int64_t half_ne1 = tensor_storage.ne[1] / 2; + size_t half_bytes = (size_t)(half_ne1 * tensor_storage.ne[0]) * + ggml_type_size(tensor_storage.type) / ggml_blck_size(tensor_storage.type); + + // Base: second half of file data (shift→scale in BFL order) → dst offset 0 + TensorStorage base = tensor_storage; + base.name = new_name; + base.ne[1] = half_ne1; + base.offset = tensor_storage.offset + half_bytes; + + // Part 1: first half of file data (scale→shift in BFL order) → dst offset = half + TensorStorage part1 = tensor_storage; + part1.name = new_name + ".1"; + part1.ne[1] = half_ne1; + // part1.offset stays at original (first half of file data) + + new_map[base.name] = std::move(base); + new_map[part1.name] = std::move(part1); + LOG_INFO("diffusers fix: split-swap '%s' -> '%s' + '%s.1' (adaLN halves reordered)", + old_name.c_str(), new_name.c_str(), new_name.c_str()); + } else { + tensor_storage.name = new_name; + new_map[new_name] = std::move(tensor_storage); + } } tensor_storage_map.swap(new_map); @@ -498,8 +539,62 @@ ggml_type str_to_ggml_type(const std::string& dtype) { return ttype; } +// Load sharded safetensors via model.safetensors.index.json +bool ModelLoader::init_from_safetensors_index(const std::string& index_path, const std::string& prefix) { + LOG_INFO("loading sharded safetensors from index '%s'", index_path.c_str()); + std::ifstream index_file(index_path); + if (!index_file.is_open()) { + LOG_ERROR("failed to open index file '%s'", index_path.c_str()); + return false; + } + + nlohmann::json index; + try { + index = nlohmann::json::parse(index_file); + } catch (const std::exception& e) { + LOG_ERROR("failed to parse index file '%s': %s", index_path.c_str(), e.what()); + return false; + } + + if 
(!index.contains("weight_map") || !index["weight_map"].is_object()) { + LOG_ERROR("invalid index file '%s': missing weight_map", index_path.c_str()); + return false; + } + + // Collect unique shard filenames preserving order + std::vector<std::string> shard_files; + std::set<std::string> seen; + for (auto& [tensor_name, shard_file] : index["weight_map"].items()) { + std::string fname = shard_file.get<std::string>(); + if (seen.insert(fname).second) { + shard_files.push_back(fname); + } + } + + // Resolve shard paths relative to index file directory + std::string dir = index_path.substr(0, index_path.find_last_of("/\\")); + int loaded = 0; + for (const auto& shard : shard_files) { + std::string shard_path = path_join(dir, shard); + if (!init_from_safetensors_file(shard_path, prefix)) { + LOG_ERROR("failed to load shard '%s'", shard_path.c_str()); + return false; + } + loaded++; + } + + LOG_INFO("loaded %d shards from '%s'", loaded, index_path.c_str()); + return true; +} + // https://huggingface.co/docs/safetensors/index bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) { + // Check for sharded safetensors (model.safetensors.index.json alongside) + std::string index_path = file_path + ".index.json"; + if (!file_exists(file_path) && file_exists(index_path)) { + return init_from_safetensors_index(index_path, prefix); + } + LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str()); file_paths_.push_back(file_path); size_t file_index = file_paths_.size() - 1; @@ -642,23 +737,57 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const /*================================================= DiffusersModelLoader ==================================================*/ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) { - std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); - std::string vae_path = path_join(file_path,
"vae/diffusion_pytorch_model.safetensors"); - std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); - std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors"); + // Diffusion model: try transformer/ (DiT models: FLUX, SD3) then unet/ (SD1/SDXL) + std::string dit_path = path_join(file_path, "transformer/diffusion_pytorch_model.safetensors"); + std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); + std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors"); + + bool diffusion_loaded = false; + if (file_exists(dit_path) || file_exists(dit_path + ".index.json")) { + diffusion_loaded = init_from_safetensors_file(dit_path, "unet."); + } else { + diffusion_loaded = init_from_safetensors_file(unet_path, "unet."); + } - if (!init_from_safetensors_file(unet_path, "unet.")) { + if (!diffusion_loaded) { return false; } if (!init_from_safetensors_file(vae_path, "vae.")) { LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); - // return false; } - if (!init_from_safetensors_file(clip_path, "te.")) { + + // Determine text encoder type from model_index.json + // LLM-based encoders (Qwen, Llama) need "text_encoders.llm." prefix, + // CLIP-based encoders need "te." prefix + std::string te_prefix = "te."; + std::string model_index_path = path_join(file_path, "model_index.json"); + if (file_exists(model_index_path)) { + std::ifstream mi_file(model_index_path); + if (mi_file.is_open()) { + try { + nlohmann::json mi = nlohmann::json::parse(mi_file); + if (mi.contains("text_encoder") && mi["text_encoder"].is_array() && mi["text_encoder"].size() >= 2) { + std::string te_class = mi["text_encoder"][1].get<std::string>(); + // LLM-based text encoders: Qwen, Llama, Gemma, etc.
+ if (te_class.find("ForCausalLM") != std::string::npos || + te_class.find("LlamaModel") != std::string::npos) { + te_prefix = "text_encoders.llm."; + LOG_INFO("detected LLM text encoder: %s, using prefix '%s'", te_class.c_str(), te_prefix.c_str()); + } + } + } catch (...) { + LOG_DEBUG("failed to parse model_index.json, using default te prefix"); + } + } + } + + std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); + if (!init_from_safetensors_file(clip_path, te_prefix)) { LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str()); - // return false; } + + std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors"); if (!init_from_safetensors_file(clip_g_path, "te.1.")) { LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str()); } @@ -1063,9 +1192,15 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { is_flux2 = true; } + if (tensor_storage.name.find("double_stream_modulation_img.linear.weight") != std::string::npos) { + is_flux2 = true; + } if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) { has_single_block_47 = true; } + if (tensor_storage.name.find("single_transformer_blocks.47.") != std::string::npos) { + has_single_block_47 = true; + } if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) { return VERSION_OVIS_IMAGE; } @@ -1459,6 +1594,20 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread size_t nbytes_to_read = tensor_storage.nbytes_to_read(); + // Split→fuse support: detect partial tensor loads + bool is_split = (tensor_storage.nelements() < ggml_nelements(dst_tensor)); + size_t write_nbytes = is_split + ? 
((size_t)tensor_storage.nelements() * ggml_type_size(dst_tensor->type) / ggml_blck_size(dst_tensor->type)) + : ggml_nbytes(dst_tensor); + + if (is_split) { + LOG_INFO("split-fuse write: '%s' is_host=%d, dst_offset=%zu, write_nbytes=%zu, read_nbytes=%zu, src_type=%s, dst_type=%s", + tensor_storage.name.c_str(), + (dst_tensor->buffer == nullptr || ggml_backend_buffer_is_host(dst_tensor->buffer)) ? 1 : 0, + tensor_storage.dst_offset, write_nbytes, nbytes_to_read, + ggml_type_name(tensor_storage.type), ggml_type_name(dst_tensor->type)); + } + auto read_data = [&](char* buf, size_t n) { if (zip != nullptr) { zip_entry_openbyindex(zip, tensor_storage.index_in_zip); @@ -1493,20 +1642,23 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread char* target_buf = nullptr; char* convert_buf = nullptr; if (dst_tensor->buffer == nullptr || ggml_backend_buffer_is_host(dst_tensor->buffer)) { + char* dst_data_ptr = (char*)dst_tensor->data + tensor_storage.dst_offset; if (tensor_storage.type == dst_tensor->type) { - GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); + if (!is_split) { + GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); + } if (tensor_storage.is_f64 || tensor_storage.is_i64) { read_buffer.resize(tensor_storage.nbytes_to_read()); read_buf = (char*)read_buffer.data(); } else { - read_buf = (char*)dst_tensor->data; + read_buf = dst_data_ptr; } - target_buf = (char*)dst_tensor->data; + target_buf = dst_data_ptr; } else { read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read())); read_buf = (char*)read_buffer.data(); target_buf = read_buf; - convert_buf = (char*)dst_tensor->data; + convert_buf = dst_data_ptr; } } else { read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read())); @@ -1514,7 +1666,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread target_buf = read_buf; if (tensor_storage.type != dst_tensor->type) { - 
convert_buffer.resize(ggml_nbytes(dst_tensor)); + convert_buffer.resize(write_nbytes); convert_buf = (char*)convert_buffer.data(); } } @@ -1554,7 +1706,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread if (dst_tensor->buffer != nullptr && !ggml_backend_buffer_is_host(dst_tensor->buffer)) { t0 = ggml_time_ms(); - ggml_backend_tensor_set(dst_tensor, convert_buf, 0, ggml_nbytes(dst_tensor)); + ggml_backend_tensor_set(dst_tensor, convert_buf, tensor_storage.dst_offset, write_nbytes); t1 = ggml_time_ms(); copy_to_backend_time_ms.fetch_add(t1 - t0); } @@ -1615,10 +1767,46 @@ bool ModelLoader::load_tensors(std::map& tensors, tensor_names_in_file.insert(name); } - ggml_tensor* real; + ggml_tensor* real = nullptr; if (tensors.find(name) != tensors.end()) { real = tensors[name]; } else { + // Check if this is a split tensor part (e.g., qkv.weight.1, qkv.weight.2) + // Convention: split parts have suffix .N where N is a positive integer + size_t last_dot = name.rfind('.'); + if (last_dot != std::string::npos) { + const std::string suffix = name.substr(last_dot + 1); + if (!suffix.empty() && suffix.find_first_not_of("0123456789") == std::string::npos) { + int split_idx = std::stoi(suffix); + if (split_idx > 0) { + const std::string base_name = name.substr(0, last_dot); + if (tensors.find(base_name) != tensors.end()) { + real = tensors[base_name]; + // Verify dimensions are consistent with being a split part + if (tensor_storage.ne[0] == real->ne[0] && + real->ne[1] >= tensor_storage.ne[1]) { + // Compute dst_offset by summing sizes of preceding parts + size_t dst_off = 0; + for (int j = 0; j < split_idx; j++) { + std::string part_name = (j == 0) ? base_name : (base_name + "." 
+ std::to_string(j)); + auto it = tensor_storage_map.find(part_name); + if (it != tensor_storage_map.end()) { + dst_off += (size_t)it->second.nelements() * ggml_type_size(real->type) / ggml_blck_size(real->type); + } + } + tensor_storage.dst_offset = dst_off; + LOG_INFO("split-fuse part %d: '%s' -> base '%s', dst_offset=%zu, part_bytes=%zu, dst_total=%zu", + split_idx, name.c_str(), base_name.c_str(), dst_off, + (size_t)tensor_storage.nelements() * ggml_type_size(real->type) / ggml_blck_size(real->type), + ggml_nbytes(real)); + *dst_tensor = real; + return true; + } + } + } + } + } + for (auto& ignore_tensor : ignore_tensors) { if (starts_with(name, ignore_tensor)) { return true; @@ -1633,6 +1821,29 @@ bool ModelLoader::load_tensors(std::map& tensors, real->ne[1] != tensor_storage.ne[1] || real->ne[2] != tensor_storage.ne[2] || real->ne[3] != tensor_storage.ne[3]) { + // Check if this is the base (part 0) of a split tensor group + bool is_split_base = false; + if (tensor_storage.ne[0] == real->ne[0] && + real->ne[1] > tensor_storage.ne[1] && + real->ne[1] % tensor_storage.ne[1] == 0) { + std::string name_1 = name + ".1"; + if (tensor_storage_map.find(name_1) != tensor_storage_map.end()) { + is_split_base = true; + } + } + + if (is_split_base) { + tensor_storage.dst_offset = 0; + LOG_INFO("split-fuse base: '%s', file_ne=[%lld,%lld], dst_ne=[%lld,%lld], dst_offset=0, part_bytes=%zu, dst_total=%zu", + name.c_str(), + (long long)tensor_storage.ne[0], (long long)tensor_storage.ne[1], + (long long)real->ne[0], (long long)real->ne[1], + (size_t)tensor_storage.nelements() * ggml_type_size(real->type) / ggml_blck_size(real->type), + ggml_nbytes(real)); + *dst_tensor = real; + return true; + } + LOG_ERROR( "tensor '%s' has wrong shape in model file: " "got [%d, %d, %d, %d], expected [%d, %d, %d, %d]", diff --git a/src/model.h b/src/model.h index 3af35eb7e..db09a25e7 100644 --- a/src/model.h +++ b/src/model.h @@ -192,6 +192,7 @@ struct TensorStorage { size_t file_index = 0; 
int index_in_zip = -1; // >= means stored in a zip file uint64_t offset = 0; // offset in file + mutable size_t dst_offset = 0; // byte offset within destination tensor for split→fuse loading TensorStorage() = default; @@ -306,6 +307,7 @@ class ModelLoader { bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_safetensors_index(const std::string& index_path, const std::string& prefix = ""); bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index d5d5e052c..23a8a9b7f 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -615,6 +615,79 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { return name; } +std::string convert_diffusers_dit_to_original_flux2(std::string name) { + int max_double_blocks = 100; + int max_single_blocks = 200; + static std::unordered_map<std::string, std::string> flux2_name_map; + + if (flux2_name_map.empty()) { + // --- time_guidance_embed --- + flux2_name_map["time_guidance_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux2_name_map["time_guidance_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + + // --- context_embedder / x_embedder --- + flux2_name_map["context_embedder.weight"] = "txt_in.weight"; + flux2_name_map["x_embedder.weight"] = "img_in.weight"; + + // --- shared modulations (.linear. → .lin.)
--- + flux2_name_map["double_stream_modulation_img.linear.weight"] = "double_stream_modulation_img.lin.weight"; + flux2_name_map["double_stream_modulation_img.linear.bias"] = "double_stream_modulation_img.lin.bias"; + flux2_name_map["double_stream_modulation_txt.linear.weight"] = "double_stream_modulation_txt.lin.weight"; + flux2_name_map["double_stream_modulation_txt.linear.bias"] = "double_stream_modulation_txt.lin.bias"; + flux2_name_map["single_stream_modulation.linear.weight"] = "single_stream_modulation.lin.weight"; + flux2_name_map["single_stream_modulation.linear.bias"] = "single_stream_modulation.lin.bias"; + + // --- double transformer blocks --- + for (int i = 0; i < max_double_blocks; ++i) { + std::string block_prefix = "transformer_blocks." + std::to_string(i) + "."; + std::string dst_prefix = "double_blocks." + std::to_string(i) + "."; + + // img attention + flux2_name_map[block_prefix + "attn.to_q.weight"] = dst_prefix + "img_attn.qkv.weight"; + flux2_name_map[block_prefix + "attn.to_k.weight"] = dst_prefix + "img_attn.qkv.weight.1"; + flux2_name_map[block_prefix + "attn.to_v.weight"] = dst_prefix + "img_attn.qkv.weight.2"; + flux2_name_map[block_prefix + "attn.to_out.0.weight"] = dst_prefix + "img_attn.proj.weight"; + flux2_name_map[block_prefix + "attn.norm_q.weight"] = dst_prefix + "img_attn.norm.query_norm.scale"; + flux2_name_map[block_prefix + "attn.norm_k.weight"] = dst_prefix + "img_attn.norm.key_norm.scale"; + + // txt attention + flux2_name_map[block_prefix + "attn.add_q_proj.weight"] = dst_prefix + "txt_attn.qkv.weight"; + flux2_name_map[block_prefix + "attn.add_k_proj.weight"] = dst_prefix + "txt_attn.qkv.weight.1"; + flux2_name_map[block_prefix + "attn.add_v_proj.weight"] = dst_prefix + "txt_attn.qkv.weight.2"; + flux2_name_map[block_prefix + "attn.to_add_out.weight"] = dst_prefix + "txt_attn.proj.weight"; + flux2_name_map[block_prefix + "attn.norm_added_q.weight"] = dst_prefix + "txt_attn.norm.query_norm.scale"; + 
flux2_name_map[block_prefix + "attn.norm_added_k.weight"] = dst_prefix + "txt_attn.norm.key_norm.scale"; + + // img mlp (SwiGLU: linear_in/linear_out) + flux2_name_map[block_prefix + "ff.linear_in.weight"] = dst_prefix + "img_mlp.0.weight"; + flux2_name_map[block_prefix + "ff.linear_out.weight"] = dst_prefix + "img_mlp.2.weight"; + + // txt mlp (SwiGLU: linear_in/linear_out) + flux2_name_map[block_prefix + "ff_context.linear_in.weight"] = dst_prefix + "txt_mlp.0.weight"; + flux2_name_map[block_prefix + "ff_context.linear_out.weight"] = dst_prefix + "txt_mlp.2.weight"; + } + + // --- single transformer blocks --- + for (int i = 0; i < max_single_blocks; ++i) { + std::string block_prefix = "single_transformer_blocks." + std::to_string(i) + "."; + std::string dst_prefix = "single_blocks." + std::to_string(i) + "."; + + flux2_name_map[block_prefix + "attn.to_qkv_mlp_proj.weight"] = dst_prefix + "linear1.weight"; + flux2_name_map[block_prefix + "attn.to_out.weight"] = dst_prefix + "linear2.weight"; + flux2_name_map[block_prefix + "attn.norm_q.weight"] = dst_prefix + "norm.query_norm.scale"; + flux2_name_map[block_prefix + "attn.norm_k.weight"] = dst_prefix + "norm.key_norm.scale"; + } + + // --- final layers --- + flux2_name_map["proj_out.weight"] = "final_layer.linear.weight"; + flux2_name_map["norm_out.linear.weight"] = "final_layer.adaLN_modulation.1.weight"; + } + + replace_with_prefix_map(name, flux2_name_map); + + return name; +} + std::string convert_diffusers_dit_to_original_lumina2(std::string name) { int num_layers = 30; int num_refiner_layers = 2; @@ -668,8 +741,10 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_unet_to_original_sdxl(name); } else if (sd_version_is_sd3(version)) { name = convert_diffusers_dit_to_original_sd3(name); - } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { + } else if (sd_version_is_flux(version)) { name = 
convert_diffusers_dit_to_original_flux(name); + } else if (sd_version_is_flux2(version)) { + name = convert_diffusers_dit_to_original_flux2(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); } else if (sd_version_is_anima(version)) { diff --git a/tests/test_flux2_name_conversion.cpp b/tests/test_flux2_name_conversion.cpp new file mode 100644 index 000000000..9ea386eac --- /dev/null +++ b/tests/test_flux2_name_conversion.cpp @@ -0,0 +1,242 @@ +// Unit test: FLUX.2-klein diffusers→BFL tensor name conversion +// Tests convert_tensor_name() with VERSION_FLUX2_KLEIN through the full +// public API path: prefix mapping → convert_diffusion_model_name() → +// convert_diffusers_dit_to_original_flux2() +// +// Build: +// c++ -std=c++17 -I../src -I../ggml/include -I../thirdparty \ +// tests/test_flux2_name_conversion.cpp \ +// -L build/bin -lstable-diffusion \ +// -L build/ggml/src -lggml -lggml-base -lggml-cpu \ +// -framework Foundation -framework Metal -framework Accelerate \ +// -o build/bin/test_flux2_name_conversion + +#include <cstdio> +#include <cstring> +#include <string> +#include <vector> + +#include "model.h" +#include "name_conversion.h" + +static int g_pass = 0; +static int g_fail = 0; + +// When loading via --diffusion-model, tensors get prefix "model.diffusion_model." +// and convert_tensor_name strips "transformer." → "model.diffusion_model." +// So input names like "transformer_blocks.0.attn.to_q.weight" become +// "model.diffusion_model.transformer_blocks.0.attn.to_q.weight" after prefix map, +// then convert_diffusion_model_name strips "model.diffusion_model." prefix, +// calls convert_diffusers_dit_to_original_flux2(), and re-adds the prefix. + +static void check(const char* test_name, + const std::string& diffusers_name, + const std::string& expected_bfl_name, + SDVersion version = VERSION_FLUX2_KLEIN) { + // Simulate the input as it comes from diffusers file loaded via --diffusion-model: + // The prefix "model.diffusion_model."
is added by init_from_file_and_convert_name + std::string input = "model.diffusion_model." + diffusers_name; + std::string expected = "model.diffusion_model." + expected_bfl_name; + + std::string result = convert_tensor_name(input, version); + + if (result == expected) { + g_pass++; + } else { + g_fail++; + fprintf(stderr, "FAIL [%s]\n input: %s\n expected: %s\n got: %s\n\n", + test_name, input.c_str(), expected.c_str(), result.c_str()); + } +} + +int main() { + printf("=== FLUX.2-klein diffusers→BFL name conversion tests ===\n\n"); + + // --------------------------------------------------------------- + // 1. Time/guidance embedders + // --------------------------------------------------------------- + check("time_in.in_layer", + "time_guidance_embed.timestep_embedder.linear_1.weight", + "time_in.in_layer.weight"); + check("time_in.out_layer", + "time_guidance_embed.timestep_embedder.linear_2.weight", + "time_in.out_layer.weight"); + + // --------------------------------------------------------------- + // 2. Input embedders + // --------------------------------------------------------------- + check("txt_in (context_embedder)", + "context_embedder.weight", + "txt_in.weight"); + check("img_in (x_embedder)", + "x_embedder.weight", + "img_in.weight"); + + // --------------------------------------------------------------- + // 3. Shared modulations (.linear. → .lin.) 
+ // --------------------------------------------------------------- + check("double_mod_img.weight", + "double_stream_modulation_img.linear.weight", + "double_stream_modulation_img.lin.weight"); + check("double_mod_img.bias", + "double_stream_modulation_img.linear.bias", + "double_stream_modulation_img.lin.bias"); + check("double_mod_txt.weight", + "double_stream_modulation_txt.linear.weight", + "double_stream_modulation_txt.lin.weight"); + check("double_mod_txt.bias", + "double_stream_modulation_txt.linear.bias", + "double_stream_modulation_txt.lin.bias"); + check("single_mod.weight", + "single_stream_modulation.linear.weight", + "single_stream_modulation.lin.weight"); + check("single_mod.bias", + "single_stream_modulation.linear.bias", + "single_stream_modulation.lin.bias"); + + // --------------------------------------------------------------- + // 4. Double blocks — block 0 (first) + // --------------------------------------------------------------- + // img attention split q/k/v → fused qkv + check("dbl0.img_attn.q", + "transformer_blocks.0.attn.to_q.weight", + "double_blocks.0.img_attn.qkv.weight"); + check("dbl0.img_attn.k", + "transformer_blocks.0.attn.to_k.weight", + "double_blocks.0.img_attn.qkv.weight.1"); + check("dbl0.img_attn.v", + "transformer_blocks.0.attn.to_v.weight", + "double_blocks.0.img_attn.qkv.weight.2"); + check("dbl0.img_attn.proj", + "transformer_blocks.0.attn.to_out.0.weight", + "double_blocks.0.img_attn.proj.weight"); + check("dbl0.img_attn.norm_q", + "transformer_blocks.0.attn.norm_q.weight", + "double_blocks.0.img_attn.norm.query_norm.scale"); + check("dbl0.img_attn.norm_k", + "transformer_blocks.0.attn.norm_k.weight", + "double_blocks.0.img_attn.norm.key_norm.scale"); + + // txt attention split q/k/v + check("dbl0.txt_attn.q", + "transformer_blocks.0.attn.add_q_proj.weight", + "double_blocks.0.txt_attn.qkv.weight"); + check("dbl0.txt_attn.k", + "transformer_blocks.0.attn.add_k_proj.weight", + 
"double_blocks.0.txt_attn.qkv.weight.1"); + check("dbl0.txt_attn.v", + "transformer_blocks.0.attn.add_v_proj.weight", + "double_blocks.0.txt_attn.qkv.weight.2"); + check("dbl0.txt_attn.proj", + "transformer_blocks.0.attn.to_add_out.weight", + "double_blocks.0.txt_attn.proj.weight"); + check("dbl0.txt_attn.norm_q", + "transformer_blocks.0.attn.norm_added_q.weight", + "double_blocks.0.txt_attn.norm.query_norm.scale"); + check("dbl0.txt_attn.norm_k", + "transformer_blocks.0.attn.norm_added_k.weight", + "double_blocks.0.txt_attn.norm.key_norm.scale"); + + // img MLP (SwiGLU) + check("dbl0.img_mlp.in", + "transformer_blocks.0.ff.linear_in.weight", + "double_blocks.0.img_mlp.0.weight"); + check("dbl0.img_mlp.out", + "transformer_blocks.0.ff.linear_out.weight", + "double_blocks.0.img_mlp.2.weight"); + + // txt MLP (SwiGLU) + check("dbl0.txt_mlp.in", + "transformer_blocks.0.ff_context.linear_in.weight", + "double_blocks.0.txt_mlp.0.weight"); + check("dbl0.txt_mlp.out", + "transformer_blocks.0.ff_context.linear_out.weight", + "double_blocks.0.txt_mlp.2.weight"); + + // --------------------------------------------------------------- + // 5. Double blocks — block 4 (last for klein-4B with depth=5) + // --------------------------------------------------------------- + check("dbl4.img_attn.q", + "transformer_blocks.4.attn.to_q.weight", + "double_blocks.4.img_attn.qkv.weight"); + check("dbl4.txt_attn.v", + "transformer_blocks.4.attn.add_v_proj.weight", + "double_blocks.4.txt_attn.qkv.weight.2"); + check("dbl4.img_mlp.out", + "transformer_blocks.4.ff.linear_out.weight", + "double_blocks.4.img_mlp.2.weight"); + + // --------------------------------------------------------------- + // 6. 
Single blocks — block 0 (first) + // --------------------------------------------------------------- + check("sgl0.linear1 (fused qkv+mlp)", + "single_transformer_blocks.0.attn.to_qkv_mlp_proj.weight", + "single_blocks.0.linear1.weight"); + check("sgl0.linear2 (out proj)", + "single_transformer_blocks.0.attn.to_out.weight", + "single_blocks.0.linear2.weight"); + check("sgl0.norm_q", + "single_transformer_blocks.0.attn.norm_q.weight", + "single_blocks.0.norm.query_norm.scale"); + check("sgl0.norm_k", + "single_transformer_blocks.0.attn.norm_k.weight", + "single_blocks.0.norm.key_norm.scale"); + + // --------------------------------------------------------------- + // 7. Single blocks — block 19 (last for klein-4B with depth_single=20) + // --------------------------------------------------------------- + check("sgl19.linear1", + "single_transformer_blocks.19.attn.to_qkv_mlp_proj.weight", + "single_blocks.19.linear1.weight"); + check("sgl19.linear2", + "single_transformer_blocks.19.attn.to_out.weight", + "single_blocks.19.linear2.weight"); + + // --------------------------------------------------------------- + // 8. Final layers + // --------------------------------------------------------------- + check("final_layer.linear", + "proj_out.weight", + "final_layer.linear.weight"); + check("final_layer.adaLN", + "norm_out.linear.weight", + "final_layer.adaLN_modulation.1.weight"); + + // --------------------------------------------------------------- + // 9. Higher block indices (for 9B: depth=10 double, depth_single=42) + // --------------------------------------------------------------- + check("dbl9.img_attn.q (9B)", + "transformer_blocks.9.attn.to_q.weight", + "double_blocks.9.img_attn.qkv.weight"); + check("sgl41.linear1 (9B)", + "single_transformer_blocks.41.attn.to_qkv_mlp_proj.weight", + "single_blocks.41.linear1.weight"); + + // --------------------------------------------------------------- + // 10. 
VERSION_FLUX2 (non-klein) should also route to flux2 converter + // --------------------------------------------------------------- + check("flux2_version_routing", + "double_stream_modulation_img.linear.weight", + "double_stream_modulation_img.lin.weight", + VERSION_FLUX2); + + // --------------------------------------------------------------- + // 11. Identity: names that don't match any mapping should pass through + // --------------------------------------------------------------- + check("unknown_passthrough", + "some_unknown_tensor.weight", + "some_unknown_tensor.weight"); + + // --------------------------------------------------------------- + // Summary + // --------------------------------------------------------------- + printf("\n=== Results: %d passed, %d failed, %d total ===\n", + g_pass, g_fail, g_pass + g_fail); + + if (g_fail > 0) { + printf("FAILED\n"); + return 1; + } + printf("ALL PASSED\n"); + return 0; +}