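"""Extend a Mistral checkpoint from the 32000-token v1 vocabulary to the
32768-token v2/v3 vocabulary.

The newer tokenizers insert their additional control tokens between the
original control tokens (IDs 0..2) and the regular pieces (IDs >= 3), so this
script shifts the original embedding and output rows up by the vocabulary
delta and randomly initializes the freed rows in between.
"""
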
import argparse
import json
import math
from pathlib import Path

import torch
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer

from model.args import ModelArgs

# SentencePiece reserves IDs 0..2 for the <unk>, <s> and </s> control tokens;
# regular pieces start at ID 3.
FIRST_PIECE_ID = 3
OLD_VOCAB_SIZE = 32000  # v1 tokenizer vocab size (the only size this script can extend)
NEW_VOCAB_SIZE = 32768  # v2/v3 tokenizer vocab size
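#
# After extension, the embedding and output matrices have this row layout
# (vocabulary_delta = 32768 - 32000 = 768):
#
#   rows 0..2         <unk>, <s>, </s>            copied from the original model
#   rows 3..770       new v2/v3 control tokens    randomly initialized
#   rows 771..32767   original pieces, shifted    copied from the original model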


def extend_model(original_model: Path, extended_model: Path):
    original_ckpt = torch.load(str(original_model / "consolidated.00.pth"), mmap=True)
    model_args = ModelArgs.load(str(original_model / "params.json"))

    original_vocab_size = model_args.vocab_size
    assert (
        original_vocab_size == OLD_VOCAB_SIZE
    ), f"Original vocab size {original_vocab_size} is not equal to {OLD_VOCAB_SIZE}. Can only extend models with a vocab_size of {OLD_VOCAB_SIZE}."

    if not extended_model.exists():
        extended_model.mkdir(parents=True)
        print(f"Created empty directory {extended_model}.")

    assert not list(
        extended_model.iterdir()
    ), f"Make sure {extended_model} is empty"

    # Load the new tokenizer and make sure it is SentencePiece-based
    mistral_tokenizer = MistralTokenizer.v3()
    tokenizer: SentencePieceTokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
    assert isinstance(tokenizer, SentencePieceTokenizer)

    new_vocab_size = tokenizer.n_words
    assert (
        new_vocab_size == NEW_VOCAB_SIZE
    ), f"New tokenizer has vocab_size {new_vocab_size}, but it has to be equal to {NEW_VOCAB_SIZE}. Make sure to pass a v2 or v3 tokenizer file."

    vocabulary_delta = new_vocab_size - original_vocab_size

    # Sanity-check the new tokenizer's layout: ID FIRST_PIECE_ID - 1 is still the
    # last control token (</s>) and the original first piece ("<0x00>") has been
    # shifted up by vocabulary_delta, i.e. the new tokens sit between the control
    # tokens and the original pieces.
    assert tokenizer._model.id_to_piece(vocabulary_delta + FIRST_PIECE_ID) == "<0x00>"
    assert tokenizer._model.id_to_piece(FIRST_PIECE_ID - 1) == "</s>"

    original_embeddings = original_ckpt["tok_embeddings.weight"]

    assert (
        original_vocab_size == original_embeddings.shape[0]
    ), f"Original vocab size {original_vocab_size} is not equal to original embeddings shape {original_embeddings.shape[0]}."

    dim = original_embeddings.shape[1]

    # Extend the embeddings: keep the control-token rows at the front and shift the
    # original piece rows up by vocabulary_delta, leaving a gap for the new tokens
    extended_embeddings = torch.zeros(
        tokenizer.n_words, dim, dtype=original_embeddings.dtype
    )
    extended_embeddings[:FIRST_PIECE_ID] = original_embeddings[:FIRST_PIECE_ID]
    extended_embeddings[FIRST_PIECE_ID + vocabulary_delta :] = original_embeddings[
        FIRST_PIECE_ID:
    ]

    # Randomly initialize the rows for the newly inserted tokens (std = 1/sqrt(dim))
    extended_tokens = torch.empty(
        vocabulary_delta, dim, dtype=original_embeddings.dtype
    )
    torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim))

    extended_embeddings[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = (
        extended_tokens
    )

    # Extend the output (LM head) weights with the same layout
    original_output = original_ckpt["output.weight"]
    assert (
        original_output.shape[0] == original_vocab_size
    ), f"Original output shape {original_output.shape[0]} is not equal to {original_vocab_size}."
    assert (
        original_output.shape[1] == dim
    ), f"Original output dim {original_output.shape[1]} is not equal to embedding dim {dim}."

    assert (
        original_output.dtype == original_embeddings.dtype
    ), f"Original output and embeddings have different dtypes: {original_output.dtype} vs {original_embeddings.dtype}."

    extended_output = torch.zeros(tokenizer.n_words, dim, dtype=original_output.dtype)
    extended_output[:FIRST_PIECE_ID] = original_output[:FIRST_PIECE_ID]
    extended_output[FIRST_PIECE_ID + vocabulary_delta :] = original_output[
        FIRST_PIECE_ID:
    ]

    # Randomly initialize the new output rows as well (std = 1/sqrt(dim))
    extended_tokens = torch.empty(vocabulary_delta, dim, dtype=original_output.dtype)
    torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim))

    extended_output[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = (
        extended_tokens
    )

    original_ckpt["tok_embeddings.weight"] = extended_embeddings
    original_ckpt["output.weight"] = extended_output

    new_ckpt_path = extended_model / "consolidated.00.pth"
    print(f"Exporting extended model to {extended_model} ...")
    torch.save(original_ckpt, new_ckpt_path)

    params_path = extended_model / "params.json"
    with open(params_path, "w") as f:
        model_dict = model_args.to_dict()
        # Drop fields that do not belong in the exported params.json and record
        # the new vocabulary size
        del model_dict["lora"]
        if model_dict["moe"] is None:
            del model_dict["moe"]
        model_dict["vocab_size"] = new_vocab_size

        f.write(json.dumps(model_dict, indent=4))


def main():
    parser = argparse.ArgumentParser(
        description="Extend a model with a 32000-token vocabulary to the 32768-token v2/v3 vocabulary."
    )
    parser.add_argument(
        "--original_model_ckpt",
        type=Path,
        required=True,
        help="Path to the original model folder.",
    )
    parser.add_argument(
        "--extended_model_ckpt",
        type=Path,
        required=True,
        help="Path to the output folder for the extended model (must be empty or not yet exist).",
    )
    args = parser.parse_args()

    extend_model(
        original_model=args.original_model_ckpt,
        extended_model=args.extended_model_ckpt,
    )


if __name__ == "__main__":
    main()
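
# Example invocation (the script name and paths are placeholders):
#
#   python extend_model_vocab.py \
#       --original_model_ckpt /path/to/original_model_folder \
#       --extended_model_ckpt /path/to/extended_model_folder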