File size: 2,401 Bytes
d3b6eff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import os, sys, glob, json
from utils_lang import *
from transformers import AutoTokenizer
def get_kept_tids():
    """Collect the sorted list of old-vocabulary token IDs to keep.

    A token ID is kept when any of the following holds:

    * it lies in the hard-coded special-token range 151643..151664;
    * its decoded text has ``vietnamese_syllable_ratio(token) > 0.8``;
    * its decoded text is at most two characters long and passes
      ``canbe_vietnamese`` or ``is_ascii``;
    * it is listed in one of the
      ``data/qwen__1000__20000/tokens_kept__*.jsonl`` files.

    The tokenizer is loaded from the current working directory and two
    summary counters plus the final vocabulary size are printed to stdout.

    Returns:
        list: the kept token IDs, sorted in ascending order.
    """
    # Keep all special tokens
    kept_tids = set(range(151643, 151664 + 1))

    tokenizer = AutoTokenizer.from_pretrained(".")

    canbe_vi_kept = 0
    is_ascii_kept = 0
    for tid in range(tokenizer.vocab_size):
        token = tokenizer.decode(tid)

        # The two Vietnamese rules are counted independently, so a token
        # matching both contributes 2 to canbe_vi_kept.
        if vietnamese_syllable_ratio(token) > 0.8:
            canbe_vi_kept += 1
            kept_tids.add(tid)

        if len(token) <= 2 and canbe_vietnamese(token):
            canbe_vi_kept += 1
            kept_tids.add(tid)

        if len(token) <= 2 and is_ascii(token):
            is_ascii_kept += 1
            kept_tids.add(tid)

    print(">>> canbe_vi_kept", canbe_vi_kept)
    print(">>> is_ascii_kept", is_ascii_kept)

    # Each JSONL record is a 3-element array; only the middle element (the
    # token ID) is used here.
    for filename in glob.glob("data/qwen__1000__20000/tokens_kept__*.jsonl"):
        # Context manager closes the handle promptly; explicit UTF-8 keeps the
        # read independent of the platform's default locale encoding.
        with open(filename, "rt", encoding="utf-8") as jsonl_file:
            for line in jsonl_file:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                _token, tid, _count = json.loads(line)
                kept_tids.add(tid)

    kept_tids = sorted(kept_tids)
    print("new_qwen_vocab", len(kept_tids))
    return kept_tids
# Token IDs retained from the original vocabulary (computed at import time).
kept_tids = get_kept_tids()

# Old <-> new vocabulary mapping: a kept token's new ID is simply its
# position inside kept_tids.
new2old = dict(enumerate(kept_tids))
old2new = {old_tid: new_tid for new_tid, old_tid in new2old.items()}

# Decoded tokens that old2new_tid could not map and has already reported,
# so each one is printed only once.
STRANGE_TOKENS = set()
def old2new_tid(x, tokenizer):
    """Map an old-vocabulary token ID to its ID in the reduced vocabulary.

    Resolution order:

    1. If ``x`` was kept, return its new ID directly.
    2. Otherwise decode ``x``; if ``contains_unwanted`` flags the text,
       return ``None``.
    3. If the text contains exactly one run of ASCII letters, re-encode that
       run alone; when it encodes to a single kept token, return that
       token's new ID.
    4. Otherwise report the token once on stdout, remember it in
       ``STRANGE_TOKENS`` and return ``None``.

    Args:
        x: token ID in the original vocabulary.
        tokenizer: object providing ``decode(token_id)`` and ``encode(text)``.

    Returns:
        The new token ID, or ``None`` when no mapping could be found.
    """
    # Local import: this file does not import `re` itself, so the name would
    # otherwise only exist if `from utils_lang import *` happens to leak it.
    import re

    if x in old2new:
        return old2new[x]

    token = tokenizer.decode(x)
    if contains_unwanted(token):
        return None

    words = re.findall(r'[a-z]+', token, flags=re.IGNORECASE)
    if len(words) > 1:
        # Diagnostic only: tokens holding several letter runs are not remapped.
        print(">>>", words)
    if len(words) == 1:
        tids = tokenizer.encode(words[0])
        if len(tids) == 1 and tids[0] in old2new:
            return old2new[tids[0]]

    # Unmappable token: log it the first time it is seen, then give up.
    if token not in STRANGE_TOKENS:
        print(f">>> old2new_tid error: id {x}, token '{token}'")
        STRANGE_TOKENS.add(token)
    return None
def round_up_to_multiple(n, multiple=64):
    """Return the smallest multiple of ``multiple`` that is >= ``n``.

    Uses integer ceiling division, so the result never drops below ``n``
    (unlike ``round(n / multiple) * multiple``, which rounds to nearest).

    Args:
        n: non-negative integer to pad.
        multiple: positive integer the result must be divisible by.

    Returns:
        int: ``n`` rounded up to a multiple of ``multiple``.
    """
    return -(-n // multiple) * multiple


if __name__ == "__main__":
    n = len(kept_tids)
    # Rounded up so the size is divisible by 64, e.g. 76138 => 76160.
    nn = round_up_to_multiple(n, 64)
    print("kept_tids", n)
    print(n, nn)
|