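"""Extend a Mistral model checkpoint from the old 32,000-token vocabulary to
the 32,768-token v3 vocabulary.

The new token rows are inserted directly after the special pieces and are
randomly initialized; the original pieces are shifted up accordingly. A sketch
of the invocation (script name and paths are placeholders):

    python extend_model_vocab.py --original_model_ckpt <original_dir> --extended_model_ckpt <extended_dir>
"""
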
import argparse
import json
import math
from pathlib import Path

import torch
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer

from model.args import ModelArgs
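
# Ids 0..FIRST_PIECE_ID - 1 are the special UNK + control pieces; the new v3
# control tokens are inserted starting at FIRST_PIECE_ID (see the id_to_piece
# asserts in extend_model below).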
FIRST_PIECE_ID = 3
OLD_VOCAB_SIZE = 32000
NEW_VOCAB_SIZE = 32768


def extend_model(original_model: Path, extended_model: Path):
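    """Extend the token embedding and output matrices of the checkpoint in
    ``original_model`` to the v3 vocabulary and write the result (consolidated
    checkpoint plus an updated params.json) to ``extended_model``.
    """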
    original_ckpt = torch.load(str(original_model / "consolidated.00.pth"), mmap=True)
    model_args = ModelArgs.load(str(original_model / "params.json"))

    original_vocab_size = model_args.vocab_size
    assert (
        original_vocab_size == OLD_VOCAB_SIZE
    ), f"Original vocab size {original_vocab_size} is not equal to {OLD_VOCAB_SIZE}. Can only extend models with a vocab_size of {OLD_VOCAB_SIZE}."

    if not extended_model.exists():
        extended_model.mkdir(parents=True, exist_ok=True)
        print(f"Created empty directory {extended_model}.")

    assert not list(
        extended_model.iterdir()
    ), f"Make sure {extended_model} is empty."

    # Load and check tokenizers
    mistral_tokenizer = MistralTokenizer.v3()
    tokenizer: SentencePieceTokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
    assert isinstance(tokenizer, SentencePieceTokenizer)

    new_vocab_size = tokenizer.n_words
    assert (
        new_vocab_size == NEW_VOCAB_SIZE
    ), f"New tokenizer has vocab_size {new_vocab_size} but it has to be equal to {NEW_VOCAB_SIZE}. Make sure to pass a v2 or v3 tokenizer file."

    vocabulary_delta = new_vocab_size - original_vocab_size

    # Check that ids 0...FIRST_PIECE_ID-1 are UNK + control characters and that
    # the first byte piece comes right after the inserted slots.
    assert tokenizer._model.id_to_piece(FIRST_PIECE_ID - 1) == "</s>"
    assert tokenizer._model.id_to_piece(vocabulary_delta + FIRST_PIECE_ID) == "<0x00>"

    original_embeddings = original_ckpt["tok_embeddings.weight"]
    assert (
        original_vocab_size == original_embeddings.shape[0]
    ), f"Original vocab size {original_vocab_size} is not equal to original embeddings shape {original_embeddings.shape[0]}."

    dim = original_embeddings.shape[1]

    # Extend embeddings: keep the special pieces in place and shift the original
    # pieces up by vocabulary_delta; the inserted rows are filled below.
    extended_embeddings = torch.zeros(
        tokenizer.n_words, dim, dtype=original_embeddings.dtype
    )
    extended_embeddings[:FIRST_PIECE_ID] = original_embeddings[:FIRST_PIECE_ID]
    extended_embeddings[FIRST_PIECE_ID + vocabulary_delta :] = original_embeddings[
        FIRST_PIECE_ID:
    ]

    # Randomly initialize the inserted token rows
    extended_tokens = torch.empty(
        vocabulary_delta, dim, dtype=original_embeddings.dtype
    )
    torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim))
    extended_embeddings[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = (
        extended_tokens
    )
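
    # The output head has one row per token as well, so "output.weight" gets the
    # same shift-and-initialize treatment to stay row-aligned with the embeddings.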
    # Extend output
    original_output = original_ckpt["output.weight"]
    assert (
        original_output.shape[0] == original_vocab_size
    ), f"Original output shape {original_output.shape[0]} is not equal to {original_vocab_size}."
    assert (
        original_output.shape[1] == dim
    ), f"Original output dim {original_output.shape[1]} is not equal to embedding dim {dim}."
    assert (
        original_output.dtype == original_embeddings.dtype
    ), f"Original output and embeddings have different dtypes: {original_output.dtype} vs {original_embeddings.dtype}."

    extended_output = torch.zeros(tokenizer.n_words, dim, dtype=original_output.dtype)
    extended_output[:FIRST_PIECE_ID] = original_output[:FIRST_PIECE_ID]
    extended_output[FIRST_PIECE_ID + vocabulary_delta :] = original_output[
        FIRST_PIECE_ID:
    ]

    # Randomly initialize the inserted token rows
    extended_tokens = torch.empty(vocabulary_delta, dim, dtype=original_output.dtype)
    torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim))
    extended_output[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = (
        extended_tokens
    )

    original_ckpt["tok_embeddings.weight"] = extended_embeddings
    original_ckpt["output.weight"] = extended_output

    new_ckpt_path = extended_model / "consolidated.00.pth"
    print(f"Exporting extended model to {extended_model} ...")
    torch.save(original_ckpt, new_ckpt_path)

    params_path = extended_model / "params.json"
    with open(params_path, "w") as f:
        model_dict = model_args.to_dict()
        # Drop the LoRA config and, when unset, the optional MoE config
        del model_dict["lora"]
        if model_dict["moe"] is None:
            del model_dict["moe"]
        model_dict["vocab_size"] = new_vocab_size
        f.write(json.dumps(model_dict, indent=4))


def main():
    parser = argparse.ArgumentParser(
        description="Extend a model's vocabulary using the specified original and extended model paths."
    )
    parser.add_argument(
        "--original_model_ckpt",
        type=Path,
        required=True,
        help="Path to the original model folder.",
    )
    parser.add_argument(
        "--extended_model_ckpt",
        type=Path,
        required=True,
        help="Path to the extended model folder.",
    )
    args = parser.parse_args()

    extend_model(
        original_model=args.original_model_ckpt,
        extended_model=args.extended_model_ckpt,
    )


if __name__ == "__main__":
    main()