194 changes: 189 additions & 5 deletions convert_hf_to_gguf.py
@@ -1490,6 +1490,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "e4d54df1ebc1f2b91acd986c5b51aa50837d5faf7c7398e73c1f9e9ee5d19869":
# ref: https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601
res = "kanana2"
if chkhsh == "5f9861fd826d8e124b222f41f41b928e78d8f6c8fbdf25625d06cc1e8736662c":
# ref: https://huggingface.co/OpenLLM-France/Luciole-1B-Base
res = "qwen2"

if res is None:
logger.warning("\n")
@@ -1515,15 +1518,179 @@ def get_vocab_base_pre(self, tokenizer) -> str:
def _set_vocab_none(self) -> None:
self.gguf_writer.add_tokenizer_model("none")

def _set_vocab_gpt2(self) -> None:
@staticmethod
def _gpt2_bytes_to_unicode() -> dict[int, str]:
        # Returns the GPT-2 byte-to-unicode mapping: each byte (0-255) maps to a
        # printable unicode character. Printable ASCII and most printable
        # Latin-1 supplement bytes map to themselves; the remaining bytes are
        # remapped to code points 256 and above.
        # This is the same as openai/gpt-2's bytes_to_unicode().
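        # A few illustrative values (they follow from the ranges below):
        #   0x41 'A'  -> 'A'          (printable ASCII maps to itself)
        #   0x20 ' '  -> 'Ġ' (U+0120) (33rd excluded byte: 256 + 32)
        #   0x0A '\n' -> 'Ċ' (U+010A) (11th excluded byte: 256 + 10)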
bs = list(range(ord("!"), ord("~") + 1)) + list(range(0xA1, 0xAC + 1)) + list(range(0xAE, 0xFF + 1))
cs = list(bs)
n = 0
for b in range(256):
if b not in bs:
bs.append(b)
cs.append(256 + n)
n += 1
return dict(zip(bs, (chr(c) for c in cs)))

    def _set_vocab_gpt2(self, convert_metaspace_to_gpt2: bool = False) -> list[str]:
tokens, toktypes, tokpre = self.get_vocab_base()

if convert_metaspace_to_gpt2:
# The tokenizer uses raw UTF-8 with Metaspace (▁ for spaces), but
# the "gpt2" tokenizer model in llama.cpp expects GPT-2 byte encoding
# (where each byte is mapped to a printable unicode char, e.g. space -> Ġ).
# Convert all tokens: replace ▁ back to space, then apply GPT-2 byte encoding.
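            # For example: "▁the" -> " the" -> "Ġthe"; the standalone "▁"
            # (UTF-8 bytes E2 96 81) would byte-encode to "âĸģ".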
byte_encoder = self._gpt2_bytes_to_unicode()
seen: set[str] = set()
for i, token in enumerate(tokens):
if toktypes[i] in (gguf.TokenType.NORMAL, gguf.TokenType.USER_DEFINED):
                    if token == " ":
                        # A raw single-space token would collide with "▁" after
                        # conversion (both become "Ġ"). It is unused in Luciole,
                        # so encode it as the literal ▁ bytes to keep it unique.
                        encoded = "".join(byte_encoder[b] for b in "\u2581".encode("utf-8"))
else:
encoded = "".join(byte_encoder[b] for b in token.replace("\u2581", " ").encode("utf-8"))
assert encoded not in seen, f"Unexpected collision in GPT-2 byte encoding: {encoded!r} for '{token}'"
seen.add(encoded)
tokens[i] = encoded
                else:  # control and other special tokens are kept verbatim
                    assert token not in seen, f"Unexpected collision in GPT-2 byte encoding: {token!r}"
                    seen.add(token)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
if convert_metaspace_to_gpt2:
special_vocab.merges = [
" ".join(
"".join(byte_encoder[b] for b in part.replace("\u2581", " ").encode("utf-8"))
for part in merge.split(" ")
)
for merge in special_vocab.merges
]
special_vocab.add_to_gguf(self.gguf_writer)
return tokens

    def _set_vocab_bpe_as_spm(self) -> list[bytes]:
"""Convert a HuggingFace BPE tokenizer (with Metaspace ▁) to SPM format for llama.cpp.

This reads the vocab from tokenizer.json, keeps tokens in their original
UTF-8 form (with ▁ preserved), assigns scores from merge ranks, and adds
byte fallback tokens <0x00>-<0xFF> required by the SPM tokenizer in C++.
"""
        import json as _json
        import re as _re

        from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))

reverse_vocab = {id_: tok for tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder

# Build merge rank lookup: token_text -> rank (lower rank = merged earlier = higher priority)
merge_ranks: dict[str, int] = {}
merges_file = self.dir_model / "tokenizer.json"
if merges_file.is_file():
with open(merges_file, "r", encoding="utf-8") as f:
tokenizer_json = _json.load(f)
merges = tokenizer_json.get("model", {}).get("merges", [])
for rank, merge in enumerate(merges):
# merge can be "token_a token_b" (str) or ["token_a", "token_b"] (list)
parts = merge.split(" ") if isinstance(merge, str) else merge
merged_token = "".join(parts)
if merged_token not in merge_ranks:
merge_ranks[merged_token] = rank
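                # e.g. a merges entry "▁t he" (or ["▁t", "he"]) at rank 3 yields
                # merge_ranks["▁the"] = 3, which the loop below turns into score -3.0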

# Prepare token arrays
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

# Track which byte values are covered (for byte fallback)
byte_token_ids: dict[int, int] = {}

for token_id in range(vocab_size):
if token_id not in reverse_vocab:
continue

token_text = reverse_vocab[token_id]

if token_id in added_tokens_decoder:
info = added_tokens_decoder[token_id]
if info.special or self.does_token_look_special(token_text):
tokens[token_id] = token_text.encode("utf-8")
scores[token_id] = 0.0
toktypes[token_id] = SentencePieceTokenTypes.CONTROL
continue

            # Check if this is a byte fallback token (<0xHH>) or a single-byte token
raw_bytes = token_text.encode("utf-8")
byte_match = _re.fullmatch(r"<0x([0-9A-Fa-f]{2})>", token_text)
if byte_match:
byte_val = int(byte_match.group(1), 16)
byte_token_ids[byte_val] = token_id
tokens[token_id] = token_text.encode("utf-8")
scores[token_id] = -10000.0
toktypes[token_id] = SentencePieceTokenTypes.BYTE
continue
elif len(raw_bytes) == 1:
byte_token_ids[raw_bytes[0]] = token_id

# Assign score based on merge rank or token_id
if token_text in merge_ranks:
# Merged tokens: earlier merges get higher (less negative) scores
# Use negative rank so that rank 0 (first merge) gets highest score
score = -float(merge_ranks[token_text])
else:
# Base tokens (single chars) get high scores; unknown tokens get low scores
if len(raw_bytes) == 1:
score = 0.0
else:
score = -10000.0 + float(token_id)

tokens[token_id] = raw_bytes
scores[token_id] = score
toktypes[token_id] = SentencePieceTokenTypes.NORMAL

# Add byte fallback tokens for any missing byte values
# SPM in llama.cpp requires <0x00> through <0xFF> with BYTE type
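        # e.g. if no existing token covers byte 0x00, the first UNUSED [PADn]
        # slot below is repurposed as b"<0x00>" with BYTE type and score -10000.0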
        next_pad_idx = 0
        for byte_val in range(256):
            if byte_val in byte_token_ids:
                continue  # already covered by an existing token
            hex_str = f"<0x{byte_val:02X}>"
            # Find an unused [PADn] slot to repurpose
            while next_pad_idx < len(tokens) and toktypes[next_pad_idx] != SentencePieceTokenTypes.UNUSED:
                next_pad_idx += 1
            if next_pad_idx < vocab_size:
                tokens[next_pad_idx] = hex_str.encode("utf-8")
                toktypes[next_pad_idx] = SentencePieceTokenTypes.BYTE
                scores[next_pad_idx] = -10000.0
                next_pad_idx += 1
            else:
                logger.warning(f"No room to add byte fallback token {hex_str}")

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
return tokens

def _set_vocab_qwen(self):
dir_model = self.dir_model
@@ -9607,14 +9774,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


# Debug toggle: when True, convert Luciole's Metaspace vocab to GPT-2 byte
# encoding instead of the SPM-style conversion (see NemotronModel.set_vocab).
LUCIOLE_TO_BPE = False

@ModelBase.register("NemotronForCausalLM")
class NemotronModel(TextModel):
model_arch = gguf.MODEL_ARCH.NEMOTRON

def set_vocab(self):
self._set_vocab_sentencepiece()
self.gguf_writer.add_pad_token_id(0)
self.gguf_writer.add_unk_token_id(1)
if (self.dir_model / "tokenizer.model").is_file():
self._set_vocab_sentencepiece()
self.gguf_writer.add_pad_token_id(0)
self.gguf_writer.add_unk_token_id(1)
        else:
            # No tokenizer.model: Luciole-style checkpoints ship only a HF BPE tokenizer.json
if LUCIOLE_TO_BPE:
tokens = self._set_vocab_gpt2(convert_metaspace_to_gpt2=True)
self.gguf_writer.add_pad_token_id(tokens.index("<pad>"))
self.gguf_writer.add_unk_token_id(tokens.index("<unk>"))
else:
tokens = self._set_vocab_bpe_as_spm()
self.gguf_writer.add_pad_token_id(tokens.index(b"<pad>"))
self.gguf_writer.add_unk_token_id(tokens.index(b"<unk>"))
self.gguf_writer.add_add_space_prefix(True)

def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -9645,6 +9825,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if name.endswith("norm.weight"):
data_torch = data_torch + 1

# for tied embeddings, duplicate token_embd as output.weight
if self.hparams.get("tie_word_embeddings", False) and name == "model.embed_tokens.weight":
yield (self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch)

yield from super().modify_tensors(data_torch, name, bid)

