Skip to content
249 changes: 230 additions & 19 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
} else if (is_safetensors_file(file_path)) {
LOG_INFO("load %s using safetensors format", file_path.c_str());
return init_from_safetensors_file(file_path, prefix);
} else if (ends_with(file_path, ".safetensors.index.json") && file_exists(file_path)) {
LOG_INFO("load %s using sharded safetensors format", file_path.c_str());
return init_from_safetensors_index(file_path, prefix);
} else if (ends_with(file_path, ".safetensors") && file_exists(file_path + ".index.json")) {
LOG_INFO("load %s using sharded safetensors format (index found)", file_path.c_str());
return init_from_safetensors_index(file_path + ".index.json", prefix);
} else if (is_zip_file(file_path)) {
LOG_INFO("load %s using checkpoint format", file_path.c_str());
return init_from_ckpt_file(file_path, prefix);
Expand All @@ -386,10 +392,45 @@ void ModelLoader::convert_tensors_name() {
String2TensorStorage new_map;

for (auto& [_, tensor_storage] : tensor_storage_map) {
auto new_name = convert_tensor_name(tensor_storage.name, version);
// LOG_DEBUG("%s -> %s", tensor_storage.name.c_str(), new_name.c_str());
tensor_storage.name = new_name;
new_map[new_name] = std::move(tensor_storage);
std::string old_name = tensor_storage.name;
auto new_name = convert_tensor_name(old_name, version);
// LOG_DEBUG("%s -> %s", old_name.c_str(), new_name.c_str());

// FLUX.2 diffusers fix: norm_out.linear.weight stores [shift, scale] while
// BFL's final_layer.adaLN_modulation.1.weight stores [scale, shift] (swapped halves).
// When loading from diffusers (name changed), split into two halves and swap their
// file offsets so the split-fuse mechanism loads them in the correct order.
bool needs_half_swap = (old_name != new_name &&
sd_version_is_flux2(version) &&
ends_with(new_name, "final_layer.adaLN_modulation.1.weight") &&
tensor_storage.n_dims == 2 &&
tensor_storage.ne[1] % 2 == 0);

if (needs_half_swap) {
int64_t half_ne1 = tensor_storage.ne[1] / 2;
size_t half_bytes = (size_t)(half_ne1 * tensor_storage.ne[0]) *
ggml_type_size(tensor_storage.type) / ggml_blck_size(tensor_storage.type);

// Base: second half of file data (shift→scale in BFL order) → dst offset 0
TensorStorage base = tensor_storage;
base.name = new_name;
base.ne[1] = half_ne1;
base.offset = tensor_storage.offset + half_bytes;

// Part 1: first half of file data (scale→shift in BFL order) → dst offset = half
TensorStorage part1 = tensor_storage;
part1.name = new_name + ".1";
part1.ne[1] = half_ne1;
// part1.offset stays at original (first half of file data)

new_map[base.name] = std::move(base);
new_map[part1.name] = std::move(part1);
LOG_INFO("diffusers fix: split-swap '%s' -> '%s' + '%s.1' (adaLN halves reordered)",
old_name.c_str(), new_name.c_str(), new_name.c_str());
} else {
tensor_storage.name = new_name;
new_map[new_name] = std::move(tensor_storage);
}
}

tensor_storage_map.swap(new_map);
Expand Down Expand Up @@ -498,8 +539,62 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
return ttype;
}

// Load sharded safetensors via model.safetensors.index.json
//
// The index file follows the HuggingFace convention: a JSON object whose
// "weight_map" maps each tensor name to the shard file containing it.
// Each referenced shard is loaded in first-seen order with
// init_from_safetensors_file(); loading stops at the first failure.
//
// @param index_path path to the *.safetensors.index.json file
// @param prefix     tensor-name prefix forwarded to the shard loader
// @return true if every shard loaded successfully
bool ModelLoader::init_from_safetensors_index(const std::string& index_path, const std::string& prefix) {
    LOG_INFO("loading sharded safetensors from index '%s'", index_path.c_str());
    std::ifstream index_file(index_path);
    if (!index_file.is_open()) {
        LOG_ERROR("failed to open index file '%s'", index_path.c_str());
        return false;
    }

    nlohmann::json index;
    try {
        index = nlohmann::json::parse(index_file);
    } catch (const std::exception& e) {
        LOG_ERROR("failed to parse index file '%s': %s", index_path.c_str(), e.what());
        return false;
    }

    if (!index.contains("weight_map") || !index["weight_map"].is_object()) {
        LOG_ERROR("invalid index file '%s': missing weight_map", index_path.c_str());
        return false;
    }

    // Collect unique shard filenames, preserving first-seen order.
    std::vector<std::string> shard_files;
    std::set<std::string> seen;
    for (auto& [tensor_name, shard_file] : index["weight_map"].items()) {
        // Guard the type: get<std::string>() on a non-string throws
        // json::type_error, which would otherwise escape this function
        // uncaught (the try-block above only covers parse()).
        if (!shard_file.is_string()) {
            LOG_ERROR("invalid index file '%s': weight_map entry '%s' is not a string",
                      index_path.c_str(), tensor_name.c_str());
            return false;
        }
        std::string fname = shard_file.get<std::string>();
        if (seen.insert(fname).second) {
            shard_files.push_back(fname);
        }
    }

    // An empty weight_map would otherwise report success having loaded nothing.
    if (shard_files.empty()) {
        LOG_ERROR("invalid index file '%s': weight_map is empty", index_path.c_str());
        return false;
    }

    // Resolve shard paths relative to the index file's directory.
    // When index_path is a bare filename (index in the CWD), find_last_of
    // returns npos and substr(0, npos) would yield the whole filename as the
    // directory, corrupting every shard path — fall back to "." instead.
    size_t sep_pos = index_path.find_last_of("/\\");
    std::string dir = (sep_pos == std::string::npos) ? std::string(".")
                                                     : index_path.substr(0, sep_pos);
    int loaded = 0;
    for (const auto& shard : shard_files) {
        std::string shard_path = path_join(dir, shard);
        if (!init_from_safetensors_file(shard_path, prefix)) {
            LOG_ERROR("failed to load shard '%s'", shard_path.c_str());
            return false;
        }
        loaded++;
    }

    LOG_INFO("loaded %d shards from '%s'", loaded, index_path.c_str());
    return true;
}

// https://huggingface.co/docs/safetensors/index
bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
// Check for sharded safetensors (model.safetensors.index.json alongside)
std::string index_path = file_path + ".index.json";
if (!file_exists(file_path) && file_exists(index_path)) {
return init_from_safetensors_index(index_path, prefix);
}

LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
file_paths_.push_back(file_path);
size_t file_index = file_paths_.size() - 1;
Expand Down Expand Up @@ -642,23 +737,57 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
/*================================================= DiffusersModelLoader ==================================================*/

bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
// Diffusion model: try transformer/ (DiT models: FLUX, SD3) then unet/ (SD1/SDXL)
std::string dit_path = path_join(file_path, "transformer/diffusion_pytorch_model.safetensors");
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");

bool diffusion_loaded = false;
if (file_exists(dit_path) || file_exists(dit_path + ".index.json")) {
diffusion_loaded = init_from_safetensors_file(dit_path, "unet.");
} else {
diffusion_loaded = init_from_safetensors_file(unet_path, "unet.");
}

if (!init_from_safetensors_file(unet_path, "unet.")) {
if (!diffusion_loaded) {
return false;
}

if (!init_from_safetensors_file(vae_path, "vae.")) {
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
// return false;
}
if (!init_from_safetensors_file(clip_path, "te.")) {

// Determine text encoder type from model_index.json
// LLM-based encoders (Qwen, Llama) need "text_encoders.llm." prefix,
// CLIP-based encoders need "te." prefix
std::string te_prefix = "te.";
std::string model_index_path = path_join(file_path, "model_index.json");
if (file_exists(model_index_path)) {
std::ifstream mi_file(model_index_path);
if (mi_file.is_open()) {
try {
nlohmann::json mi = nlohmann::json::parse(mi_file);
if (mi.contains("text_encoder") && mi["text_encoder"].is_array() && mi["text_encoder"].size() >= 2) {
std::string te_class = mi["text_encoder"][1].get<std::string>();
// LLM-based text encoders: Qwen, Llama, Gemma, etc.
if (te_class.find("ForCausalLM") != std::string::npos ||
te_class.find("LlamaModel") != std::string::npos) {
te_prefix = "text_encoders.llm.";
LOG_INFO("detected LLM text encoder: %s, using prefix '%s'", te_class.c_str(), te_prefix.c_str());
}
}
} catch (...) {
LOG_DEBUG("failed to parse model_index.json, using default te prefix");
}
}
}

std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
if (!init_from_safetensors_file(clip_path, te_prefix)) {
LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
// return false;
}

std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
}
Expand Down Expand Up @@ -1063,9 +1192,15 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
is_flux2 = true;
}
if (tensor_storage.name.find("double_stream_modulation_img.linear.weight") != std::string::npos) {
is_flux2 = true;
}
if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
has_single_block_47 = true;
}
if (tensor_storage.name.find("single_transformer_blocks.47.") != std::string::npos) {
has_single_block_47 = true;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
return VERSION_OVIS_IMAGE;
}
Expand Down Expand Up @@ -1459,6 +1594,20 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread

size_t nbytes_to_read = tensor_storage.nbytes_to_read();

// Split→fuse support: detect partial tensor loads
bool is_split = (tensor_storage.nelements() < ggml_nelements(dst_tensor));
size_t write_nbytes = is_split
? ((size_t)tensor_storage.nelements() * ggml_type_size(dst_tensor->type) / ggml_blck_size(dst_tensor->type))
: ggml_nbytes(dst_tensor);

if (is_split) {
LOG_INFO("split-fuse write: '%s' is_host=%d, dst_offset=%zu, write_nbytes=%zu, read_nbytes=%zu, src_type=%s, dst_type=%s",
tensor_storage.name.c_str(),
(dst_tensor->buffer == nullptr || ggml_backend_buffer_is_host(dst_tensor->buffer)) ? 1 : 0,
tensor_storage.dst_offset, write_nbytes, nbytes_to_read,
ggml_type_name(tensor_storage.type), ggml_type_name(dst_tensor->type));
}

auto read_data = [&](char* buf, size_t n) {
if (zip != nullptr) {
zip_entry_openbyindex(zip, tensor_storage.index_in_zip);
Expand Down Expand Up @@ -1493,28 +1642,31 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
char* target_buf = nullptr;
char* convert_buf = nullptr;
if (dst_tensor->buffer == nullptr || ggml_backend_buffer_is_host(dst_tensor->buffer)) {
char* dst_data_ptr = (char*)dst_tensor->data + tensor_storage.dst_offset;
if (tensor_storage.type == dst_tensor->type) {
GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes());
if (!is_split) {
GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes());
}
if (tensor_storage.is_f64 || tensor_storage.is_i64) {
read_buffer.resize(tensor_storage.nbytes_to_read());
read_buf = (char*)read_buffer.data();
} else {
read_buf = (char*)dst_tensor->data;
read_buf = dst_data_ptr;
}
target_buf = (char*)dst_tensor->data;
target_buf = dst_data_ptr;
} else {
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
read_buf = (char*)read_buffer.data();
target_buf = read_buf;
convert_buf = (char*)dst_tensor->data;
convert_buf = dst_data_ptr;
}
} else {
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
read_buf = (char*)read_buffer.data();
target_buf = read_buf;

if (tensor_storage.type != dst_tensor->type) {
convert_buffer.resize(ggml_nbytes(dst_tensor));
convert_buffer.resize(write_nbytes);
convert_buf = (char*)convert_buffer.data();
}
}
Expand Down Expand Up @@ -1554,7 +1706,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread

if (dst_tensor->buffer != nullptr && !ggml_backend_buffer_is_host(dst_tensor->buffer)) {
t0 = ggml_time_ms();
ggml_backend_tensor_set(dst_tensor, convert_buf, 0, ggml_nbytes(dst_tensor));
ggml_backend_tensor_set(dst_tensor, convert_buf, tensor_storage.dst_offset, write_nbytes);
t1 = ggml_time_ms();
copy_to_backend_time_ms.fetch_add(t1 - t0);
}
Expand Down Expand Up @@ -1615,10 +1767,46 @@ bool ModelLoader::load_tensors(std::map<std::string, ggml_tensor*>& tensors,
tensor_names_in_file.insert(name);
}

ggml_tensor* real;
ggml_tensor* real = nullptr;
if (tensors.find(name) != tensors.end()) {
real = tensors[name];
} else {
// Check if this is a split tensor part (e.g., qkv.weight.1, qkv.weight.2)
// Convention: split parts have suffix .N where N is a positive integer
size_t last_dot = name.rfind('.');
if (last_dot != std::string::npos) {
const std::string suffix = name.substr(last_dot + 1);
if (!suffix.empty() && suffix.find_first_not_of("0123456789") == std::string::npos) {
int split_idx = std::stoi(suffix);
if (split_idx > 0) {
const std::string base_name = name.substr(0, last_dot);
if (tensors.find(base_name) != tensors.end()) {
real = tensors[base_name];
// Verify dimensions are consistent with being a split part
if (tensor_storage.ne[0] == real->ne[0] &&
real->ne[1] >= tensor_storage.ne[1]) {
// Compute dst_offset by summing sizes of preceding parts
size_t dst_off = 0;
for (int j = 0; j < split_idx; j++) {
std::string part_name = (j == 0) ? base_name : (base_name + "." + std::to_string(j));
auto it = tensor_storage_map.find(part_name);
if (it != tensor_storage_map.end()) {
dst_off += (size_t)it->second.nelements() * ggml_type_size(real->type) / ggml_blck_size(real->type);
}
}
tensor_storage.dst_offset = dst_off;
LOG_INFO("split-fuse part %d: '%s' -> base '%s', dst_offset=%zu, part_bytes=%zu, dst_total=%zu",
split_idx, name.c_str(), base_name.c_str(), dst_off,
(size_t)tensor_storage.nelements() * ggml_type_size(real->type) / ggml_blck_size(real->type),
ggml_nbytes(real));
*dst_tensor = real;
return true;
}
}
}
}
}

for (auto& ignore_tensor : ignore_tensors) {
if (starts_with(name, ignore_tensor)) {
return true;
Expand All @@ -1633,6 +1821,29 @@ bool ModelLoader::load_tensors(std::map<std::string, ggml_tensor*>& tensors,
real->ne[1] != tensor_storage.ne[1] ||
real->ne[2] != tensor_storage.ne[2] ||
real->ne[3] != tensor_storage.ne[3]) {
// Check if this is the base (part 0) of a split tensor group
bool is_split_base = false;
if (tensor_storage.ne[0] == real->ne[0] &&
real->ne[1] > tensor_storage.ne[1] &&
real->ne[1] % tensor_storage.ne[1] == 0) {
std::string name_1 = name + ".1";
if (tensor_storage_map.find(name_1) != tensor_storage_map.end()) {
is_split_base = true;
}
}

if (is_split_base) {
tensor_storage.dst_offset = 0;
LOG_INFO("split-fuse base: '%s', file_ne=[%lld,%lld], dst_ne=[%lld,%lld], dst_offset=0, part_bytes=%zu, dst_total=%zu",
name.c_str(),
(long long)tensor_storage.ne[0], (long long)tensor_storage.ne[1],
(long long)real->ne[0], (long long)real->ne[1],
(size_t)tensor_storage.nelements() * ggml_type_size(real->type) / ggml_blck_size(real->type),
ggml_nbytes(real));
*dst_tensor = real;
return true;
}

LOG_ERROR(
"tensor '%s' has wrong shape in model file: "
"got [%d, %d, %d, %d], expected [%d, %d, %d, %d]",
Expand Down
2 changes: 2 additions & 0 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ struct TensorStorage {
size_t file_index = 0;
int index_in_zip = -1; // >= means stored in a zip file
uint64_t offset = 0; // offset in file
mutable size_t dst_offset = 0; // byte offset within destination tensor for split→fuse loading

TensorStorage() = default;

Expand Down Expand Up @@ -306,6 +307,7 @@ class ModelLoader {

bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_safetensors_index(const std::string& index_path, const std::string& prefix = "");
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");

Expand Down
Loading