# -*- encoding:utf-8 -*-
import os
from multiprocessing import Pool

from tencentpretrain.utils.constants import *
from tencentpretrain.utils.misc import count_lines


class Vocab(object):
    """
    Vocabulary that maps between tokens and integer ids.
    """

    def __init__(self):
        self.w2i = {}  # word -> id
        self.i2w = []  # id -> word
        self.w2c = {}  # word -> count
        self.reserved_vocab_path = \
            os.path.abspath(os.path.join(os.path.dirname(__file__), "../../models/reserved_vocab.txt"))

    def load(self, vocab_path, is_quiet=False):
        with open(vocab_path, mode="r", encoding="utf-8") as reader:
            for index, line in enumerate(reader):
                # Keep only the first whitespace-separated field; an empty
                # line is stored as the empty string.
                w = line.strip("\r\n").split()[0] if line.strip() else line.strip("\r\n")
                self.w2i[w] = index
                self.i2w.append(w)
        if not is_quiet:
            print("Vocabulary size: ", len(self))

    def save(self, save_path):
        print("Vocabulary size: ", len(self))
        with open(save_path, mode="w", encoding="utf-8") as f:
            for w in self.i2w:
                f.write(w + "\n")
        print("Vocabulary saving done.")

    def get(self, w):
        return self.w2i[w]

    def __len__(self):
        return len(self.i2w)

    def worker(self, corpus_path, tokenizer, start, end):
        """
        Worker that creates vocabulary from corpus[start:end].
        """
        w2i, i2w, w2c = {}, [], {}
        pos = 0
        with open(corpus_path, mode="r", encoding="utf-8") as f:
            # Skip the lines that belong to the preceding chunks.
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                # tokenizer is only either CharTokenizer or SpaceTokenizer.
                tokens = tokenizer.tokenize(line, use_vocab=False)
                for t in tokens:
                    if t not in w2i:
                        w2i[t], w2c[t] = len(i2w), 1
                        i2w.append(t)
                    else:
                        w2c[t] += 1
                # pos points one past the last line read, so the chunk
                # [start, end) is finished once pos reaches end.
                if pos >= end:
                    return (w2i, i2w, w2c)

    def union(self, vocab_list):
        """ Union vocabularies collected by all workers. """
        w2i, i2w, w2c = {}, [], {}
        for v_p in vocab_list:
            # Each element is an AsyncResult holding a (w2i, i2w, w2c) triple.
            w2i_p, i2w_p, w2c_p = v_p.get()
            for w in i2w_p:
                if w not in w2i:
                    w2i[w], w2c[w] = len(i2w), w2c_p[w]
                    i2w.append(w)
                else:
                    w2c[w] += w2c_p[w]
        return (w2i, i2w, w2c)
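
    # build() below splits the corpus into workers_num contiguous chunks of
    # lines. For example, with lines_num = 10 and workers_num = 3 the chunks
    # are [0, 3), [3, 6) and [6, 10), and each chunk is scanned independently
    # by worker(), then merged by union().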

    def build(self, corpus_path, tokenizer, workers_num=1, min_count=1):
        """ Build vocabulary from the given corpus. """
        print("Start %d workers for building vocabulary..." % workers_num)
        lines_num = count_lines(corpus_path)
        pool = Pool(workers_num)
        vocab_list = []
        for i in range(workers_num):
            start = i * lines_num // workers_num
            end = (i + 1) * lines_num // workers_num
            vocab_list.append(pool.apply_async(func=self.worker, args=[corpus_path, tokenizer, start, end]))
        pool.close()
        pool.join()

        # Union vocab in all workers.
        w2i, i2w, w2c = self.union(vocab_list)
        # Sort w2c according to word count.
        sorted_w2c = sorted(w2c.items(), key=lambda item: item[1], reverse=True)

        # Add special symbols and remove low frequency words.
        with open(self.reserved_vocab_path, mode="r", encoding="utf-8") as reader:
            self.i2w = [line.strip().split()[0] for line in reader]

        for i, w in enumerate(self.i2w):
            self.w2i[w] = i
            self.w2c[w] = -1

        for w, c in sorted_w2c:
            if c < min_count:
                break
            if w not in self.w2i:
                self.w2i[w], self.w2c[w] = len(self.i2w), c
                self.i2w.append(w)
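

# Minimal usage sketch (illustrative, with assumptions). DemoSpaceTokenizer is
# a hypothetical stand-in that only mimics the tokenize(line, use_vocab=False)
# interface worker() relies on; TencentPretrain itself would use its real
# SpaceTokenizer or CharTokenizer. "corpus.txt" and "vocab.txt" are placeholder
# paths, and build() also expects the repo's models/reserved_vocab.txt to exist.
class DemoSpaceTokenizer(object):
    """Hypothetical whitespace tokenizer used only for the demo below."""

    def tokenize(self, text, use_vocab=True):
        return text.strip().split()


if __name__ == "__main__":
    vocab = Vocab()
    # Build from a one-sentence-per-line corpus with two worker processes.
    vocab.build("corpus.txt", DemoSpaceTokenizer(), workers_num=2, min_count=1)
    vocab.save("vocab.txt")
    print("Id of first reserved symbol:", vocab.get(vocab.i2w[0]))  # prints 0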