# Build the set of Qwen token ids to keep (special tokens, Vietnamese and short
# ASCII tokens, plus ids recorded under data/qwen__1000__20000/) and an
# old-id <-> new-id mapping for the pruned vocabulary.

import os, sys, glob, json, re
from utils_lang import *
from transformers import AutoTokenizer

def get_kept_tids():
    # Collect the old token ids to keep: all special tokens, tokens that look
    # Vietnamese, short ASCII tokens, and tokens listed in the tokens_kept files.

    # Keep all special tokens
    kept_tids = set( x for x in range(151643, 151664 + 1) )

    tokenizer = AutoTokenizer.from_pretrained(".")

    canbe_vi_kept = 0
    is_ascii_kept = 0

    for tid in range(0, tokenizer.vocab_size):
        token = tokenizer.decode(tid)

        if vietnamese_syllable_ratio(token) > 0.8:
            canbe_vi_kept += 1
            kept_tids.add(tid)

        if len(token) <= 2 and canbe_vietnamese(token):
            canbe_vi_kept += 1
            kept_tids.add(tid)

        if len(token) <= 2 and is_ascii(token):
            is_ascii_kept += 1
            kept_tids.add(tid)

    print(">>> canbe_vi_kept", canbe_vi_kept)
    print(">>> is_ascii_kept", is_ascii_kept)

    kept_filenames = glob.glob("data/qwen__1000__20000/tokens_kept__*.jsonl")

    for filename in kept_filenames:
        for line in open(filename, "rt"):
            token, tid, count = json.loads(line)
            kept_tids.add(tid)

    kept_tids = list( kept_tids )
    kept_tids.sort()

    print("new_qwen_vocab", len(kept_tids))
    return kept_tids


kept_tids = get_kept_tids()

# old vs new vocab mapping
old2new = {}
new2old = {}

for new_tid, old_tid in enumerate( kept_tids ):
    old2new[ old_tid ] = new_tid
    new2old[ new_tid ] = old_tid
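
# Illustrative sanity check: the new ids should form the contiguous range
# 0..len(kept_tids)-1 and the two mappings should round-trip.
assert sorted(new2old) == list(range(len(kept_tids)))
assert all(old2new[new2old[i]] == i for i in new2old)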


STRANGE_TOKENS = set()

def old2new_tid(x, tokenizer):
    # Map an old token id to its new id. For ids that were not kept: drop
    # tokens containing unwanted characters, otherwise try to fall back to the
    # single alphabetic word inside the token; anything else is reported once
    # and dropped (returns None).
    global STRANGE_TOKENS

    if x in old2new:
        return old2new[x]

    else:
        token = tokenizer.decode(x)
        if contains_unwanted(token):
            return None

        words = re.findall(r'[a-z]+', token, flags = re.IGNORECASE)

        if len(words) > 1:
            print(">>>", words)

        if len(words) == 1:
            tids = tokenizer.encode(words[0])
            if len(tids) == 1 and tids[0] in old2new:
                return old2new[tids[0]]

        msg = f">>> old2new_tid error: id {x}, token '{token}'"
        if token not in STRANGE_TOKENS:
            print(msg)
            STRANGE_TOKENS.add( token )

        # assert False, msg
        return None

    assert False, "Unreachable: if we get here, there is a bug in the code above"
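

# Illustrative helper (a hypothetical sketch, not defined in the original
# script): remap a whole sequence of old token ids to the new vocab, dropping
# any id that old2new_tid cannot map.
def remap_ids(old_tids, tokenizer):
    new_tids = []
    for old_tid in old_tids:
        new_tid = old2new_tid(old_tid, tokenizer)
        if new_tid is not None:
            new_tids.append(new_tid)
    return new_tids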


if __name__ == "__main__":

    n = len(kept_tids)
    # Round the new vocab size to a multiple of 64
    nn = round(n / 64) * 64

    print("kept_tids", n)
    print(n, nn) # 76138 => 76160 (rounded to a multiple of 64)
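
    # Illustrative usage (the sample sentence is an arbitrary example, not part
    # of the original script): encode text with the original tokenizer and
    # remap the ids into the pruned vocabulary.
    tokenizer = AutoTokenizer.from_pretrained(".")
    sample = "Xin chào thế giới"
    old_ids = tokenizer.encode(sample)
    new_ids = [old2new_tid(tid, tokenizer) for tid in old_ids]
    print(sample, "->", old_ids, "->", new_ids)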