|
import argparse |
|
import json |
|
import math |
|
import os |
|
from pathlib import Path |
|
|
|
import torch |
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer |
|
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer |
|
|
|
from model.args import ModelArgs |
|
|
|
# Token ids below this value keep their position when the vocabulary is
# extended (the last of them is "</s>"); the newly added tokens are inserted
# starting at this id.
FIRST_PIECE_ID = 3

# Vocabulary size a checkpoint must have to be extended.
OLD_VOCAB_SIZE = 32_000

# Vocabulary size after extension: 768 extra tokens on top of the old size.
NEW_VOCAB_SIZE = OLD_VOCAB_SIZE + 768
|
|
|
|
|
def _extend_rows(original: torch.Tensor, vocabulary_delta: int) -> torch.Tensor:
    """Return ``original`` with ``vocabulary_delta`` new rows inserted at ``FIRST_PIECE_ID``.

    Row layout of the returned tensor:
      * ``[0, FIRST_PIECE_ID)``: copied unchanged from ``original``;
      * ``[FIRST_PIECE_ID, FIRST_PIECE_ID + vocabulary_delta)``: new rows drawn
        from a normal distribution with mean 0 and ``std = 1 / sqrt(dim)``;
      * the remainder: ``original[FIRST_PIECE_ID:]`` shifted up by
        ``vocabulary_delta``.

    The result has ``original``'s dtype and is allocated on the default device.
    """
    n_rows, dim = original.shape
    new_start = FIRST_PIECE_ID
    new_end = FIRST_PIECE_ID + vocabulary_delta

    # Every row is written below, so no zero-fill is needed.
    extended = torch.empty(n_rows + vocabulary_delta, dim, dtype=original.dtype)
    extended[:new_start] = original[:FIRST_PIECE_ID]
    extended[new_end:] = original[FIRST_PIECE_ID:]

    new_rows = torch.empty(vocabulary_delta, dim, dtype=original.dtype)
    torch.nn.init.normal_(new_rows, std=1 / math.sqrt(dim))
    extended[new_start:new_end] = new_rows
    return extended


def _prepare_output_dir(extended_model: Path) -> None:
    """Create ``extended_model`` if it is missing and require it to be an empty directory."""
    if not extended_model.exists():
        extended_model.mkdir(parents=True, exist_ok=True)
        print(f"Created empty directory {extended_model}.")

    assert extended_model.is_dir(), f"{extended_model} exists but is not a directory"
    assert not any(extended_model.iterdir()), f"Make sure {extended_model} is empty"


def _write_params(model_args: "ModelArgs", params_path: Path, vocab_size: int) -> None:
    """Write ``model_args`` as indented JSON to ``params_path`` with ``vocab_size`` overridden.

    The ``lora`` entry is dropped, and ``moe`` is dropped when it is ``None``.
    """
    model_dict = model_args.to_dict()
    model_dict.pop("lora", None)
    if model_dict.get("moe") is None:
        model_dict.pop("moe", None)
    model_dict["vocab_size"] = vocab_size

    # Serialize first so a failure cannot leave a truncated params.json behind.
    params_path.write_text(json.dumps(model_dict, indent=4), encoding="utf-8")


def extend_model(original_model: Path, extended_model: Path):
    """Extend a checkpoint with an ``OLD_VOCAB_SIZE`` vocabulary to the v3 tokenizer's vocabulary.

    Reads ``consolidated.00.pth`` and ``params.json`` from ``original_model``.
    It writes both files to ``extended_model``, with ``tok_embeddings.weight`` and
    ``output.weight`` grown to the new vocabulary size and ``vocab_size`` updated.

    Token id mapping (``delta = new_vocab_size - original_vocab_size``):
      * ids ``[0, FIRST_PIECE_ID)`` keep their rows;
      * ids ``[FIRST_PIECE_ID, FIRST_PIECE_ID + delta)`` are new tokens whose
        rows are randomly initialized;
      * every old id ``i >= FIRST_PIECE_ID`` moves to ``i + delta``.

    Args:
        original_model: Folder with the original checkpoint and ``params.json``
            (``str`` paths are accepted as well).
        extended_model: Output folder; it is created if missing and must be empty.

    Raises:
        AssertionError: If vocab sizes, the tokenizer layout, tensor shapes or
            dtypes, or the output folder do not match expectations.
    """
    original_model = Path(original_model)
    extended_model = Path(extended_model)

    original_ckpt = torch.load(str(original_model / "consolidated.00.pth"), mmap=True)
    model_args = ModelArgs.load(str(original_model / "params.json"))

    original_vocab_size = model_args.vocab_size
    assert original_vocab_size == OLD_VOCAB_SIZE, (
        f"Original vocab size {original_vocab_size} is not equal to {OLD_VOCAB_SIZE}. "
        f"Can only extend models with vocab_size of {OLD_VOCAB_SIZE}"
    )

    _prepare_output_dir(extended_model)

    tokenizer = MistralTokenizer.v3().instruct_tokenizer.tokenizer
    # Check the type before touching the SentencePiece-specific ``_model`` attribute.
    assert isinstance(tokenizer, SentencePieceTokenizer)

    new_vocab_size = tokenizer.n_words
    assert new_vocab_size == NEW_VOCAB_SIZE, (
        f"New tokenizer has vocab_size: {new_vocab_size} but has to be equal to "
        f"{NEW_VOCAB_SIZE}. Expected a v3 tokenizer."
    )
    vocabulary_delta = new_vocab_size - original_vocab_size

    # The id mapping above is only valid if the new tokenizer keeps "</s>" as the
    # last preserved id and places "<0x00>" right after the inserted tokens.
    assert tokenizer._model.id_to_piece(vocabulary_delta + FIRST_PIECE_ID) == "<0x00>"
    assert tokenizer._model.id_to_piece(FIRST_PIECE_ID - 1) == "</s>"

    original_embeddings = original_ckpt["tok_embeddings.weight"]
    assert original_vocab_size == original_embeddings.shape[0], (
        f"Original vocab size {original_vocab_size} is not equal to original "
        f"embeddings shape {original_embeddings.shape[0]}."
    )
    dim = original_embeddings.shape[1]

    original_output = original_ckpt["output.weight"]
    assert (
        original_output.shape[0] == original_vocab_size
    ), f"Original output shape {original_output.shape[0]} is not equal to {original_vocab_size}."
    assert (
        original_output.shape[1] == dim
    ), f"Original output dim {original_output.shape[1]} is not equal to embedding dim {dim}."
    assert original_output.dtype == original_embeddings.dtype, (
        f"Original output and embeddings have different dtypes: "
        f"{original_output.dtype} vs {original_embeddings.dtype}."
    )

    original_ckpt["tok_embeddings.weight"] = _extend_rows(original_embeddings, vocabulary_delta)
    original_ckpt["output.weight"] = _extend_rows(original_output, vocabulary_delta)

    print(f"Exporting extended model to {extended_model} ...")
    torch.save(original_ckpt, extended_model / "consolidated.00.pth")

    _write_params(model_args, extended_model / "params.json", new_vocab_size)
|
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and run ``extend_model``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous behavior.

    Exits with status 2 (argparse's usage error) when a required option is missing.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extend a model's vocabulary to the v3 tokenizer using the specified "
            "original and extended model folders."
        )
    )
    # Both paths are mandatory: without them extend_model would receive None.
    parser.add_argument(
        "--original_model_ckpt",
        type=Path,
        required=True,
        help="Path to the original model folder.",
    )
    parser.add_argument(
        "--extended_model_ckpt",
        type=Path,
        required=True,
        help="Path to the extended model folder (created if missing, must be empty).",
    )
    args = parser.parse_args(argv)

    extend_model(
        original_model=args.original_model_ckpt,
        extended_model=args.extended_model_ckpt,
    )
|
|
|
|
|
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|