|
import numpy as np |
|
import glob |
|
import os |
|
import pickle |
|
import lmdb |
|
import pyarrow |
|
import fasttext |
|
from loguru import logger |
|
from scipy import linalg |
|
|
|
|
|
class Vocab:
    """Word <-> integer-index vocabulary with per-word occurrence counts.

    Four ids are reserved by the class constants below.  The special tokens
    are stored only in ``index2word``; they are never added to ``word2index``
    or ``word2count``, so ``get_word_index("<PAD>")`` yields ``UNK_token``.

    Attributes:
        name: Free-form label given at construction.
        trimmed: True once :meth:`trim` has run (it runs at most once).
        word_embedding_weights: ``None`` until :meth:`load_word_vectors`
            fills it with a float32 array of shape ``(n_words, dim)``.
        word2index / word2count / index2word: The lookup tables.
        n_words: Index that the next new word will receive.
    """

    PAD_token = 0
    SOS_token = 1
    EOS_token = 2
    UNK_token = 3

    def __init__(self, name, insert_default_tokens=True):
        """Create an empty vocabulary.

        Args:
            name: Label stored on the instance.
            insert_default_tokens: When True, PAD/SOS/EOS/UNK are registered
                in ``index2word``; when False only UNK is.
        """
        self.name = name
        self.trimmed = False
        self.word_embedding_weights = None
        self.reset_dictionary(insert_default_tokens)

    def reset_dictionary(self, insert_default_tokens=True):
        """Drop every indexed word and re-register the special tokens."""
        # Remembered so trim() can rebuild the tables with the same layout.
        self._insert_default_tokens = insert_default_tokens
        self.word2index = {}
        self.word2count = {}
        if insert_default_tokens:
            self.index2word = {
                self.PAD_token: "<PAD>",
                self.SOS_token: "<SOS>",
                self.EOS_token: "<EOS>",
                self.UNK_token: "<UNK>",
            }
        else:
            self.index2word = {self.UNK_token: "<UNK>"}
        # The next free index must lie past every reserved id.  Using
        # len(index2word) here would hand index 3 (UNK_token) to the third
        # real word when only <UNK> is registered.
        self.n_words = max(self.index2word) + 1

    def index_word(self, word):
        """Register ``word`` (assigning the next index) or bump its count."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def add_vocab(self, other_vocab):
        """Index every word known to ``other_vocab``.

        Counts are not merged: each word of ``other_vocab`` contributes
        exactly one occurrence here, regardless of its count there.
        """
        for word in other_vocab.word2count:
            self.index_word(word)

    def trim(self, min_count):
        """Discard words seen fewer than ``min_count`` times.

        Runs at most once per instance; later calls return immediately.
        Surviving words are re-indexed from scratch, so their indices change
        and their counts restart at 1.  ``word_embedding_weights`` is not
        touched by this method.
        """
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [w for w, count in self.word2count.items() if count >= min_count]

        total = len(self.word2index)
        # Guard the ratio: an empty vocabulary must not raise ZeroDivisionError.
        ratio = len(keep_words) / total if total else 0.0
        print(' word trimming, kept %s / %s = %.4f' % (len(keep_words), total, ratio))

        # getattr: instances unpickled from an older cache lack the attribute;
        # they fall back to the previous behaviour (default tokens inserted).
        self.reset_dictionary(getattr(self, "_insert_default_tokens", True))
        for word in keep_words:
            self.index_word(word)

    def get_word_index(self, word):
        """Return the index of ``word``, or ``UNK_token`` if it is unknown."""
        return self.word2index.get(word, self.UNK_token)

    def load_word_vectors(self, pretrained_path, embedding_dim=300):
        """Fill ``word_embedding_weights`` from a fastText model file.

        Args:
            pretrained_path: Path handed to ``fasttext.load_model``.
            embedding_dim: Width of the embedding matrix.  ``None`` means
                "use the dimension reported by the loaded model".

        Rows of indexed words are overwritten with the model's vectors; all
        other rows (the reserved ids) keep their random normal initialisation
        with standard deviation ``1 / sqrt(embedding_dim)``.
        """
        print(" loading word vectors from '{}'...".format(pretrained_path))

        word_model = fasttext.load_model(pretrained_path)
        if embedding_dim is None:
            embedding_dim = word_model.get_dimension()

        init_sd = 1 / np.sqrt(embedding_dim)
        weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
        weights = weights.astype(np.float32)

        for word, idx in self.word2index.items():
            weights[idx] = word_model.get_word_vector(word)
        self.word_embedding_weights = weights
|
|
|
def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None):
    """Build (or reload) a ``Vocab`` and write it to ``cache_path``.

    Args:
        name: Label passed to ``Vocab``.
        data_path: Dataset root handed to ``index_words_from_textgrid``.
        cache_path: Pickle file that is written on every call, and read
            first when ``word_vec_path`` is None.
        word_vec_path: Pretrained word-vector file.  When given, embeddings
            are loaded into the freshly indexed vocabulary.
        feat_dim: Embedding width forwarded to ``Vocab.load_word_vectors``.

    Returns:
        The vocabulary object that was pickled to ``cache_path``.

    Raises:
        FileNotFoundError: ``word_vec_path`` is None and ``cache_path`` does
            not exist yet.
    """
    print(' building a language model...')
    lang_model = Vocab(name)
    print(' indexing words from {}'.format(data_path))
    index_words_from_textgrid(lang_model, data_path)

    if word_vec_path is not None:
        lang_model.load_word_vectors(word_vec_path, feat_dim)
    else:
        # NOTE(review): this branch replaces the vocabulary indexed above with
        # the cached one, so the indexing pass is discarded -- confirm that is
        # intended.
        print(' loaded from {}'.format(cache_path))
        # pickle executes arbitrary code on load; only use trusted cache files.
        with open(cache_path, 'rb') as f:
            lang_model = pickle.load(f)
        # No word vectors were requested, so drop any cached embeddings.  The
        # former shape-mismatch check could never fire (word_vec_path is None
        # on this path) and called the un-imported `logging` module, so it
        # has been removed.
        lang_model.word_embedding_weights = None

    with open(cache_path, 'wb') as f:
        pickle.dump(lang_model, f)

    return lang_model
|
|
|
def index_words(lang_model, data_path):
    """Index every whitespace-separated word of a plain-text file.

    The characters ``, . ? !`` are turned into spaces before splitting, so
    they act as word separators and never end up inside a token.

    Args:
        lang_model: Object exposing ``index_word(word)`` and ``n_words``.
        data_path: Path of the text file to read.
    """
    # One translate() pass replaces the four chained str.replace() calls.
    punct_to_space = str.maketrans(",.?!", "    ")

    with open(data_path, "r") as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            for word in line.translate(punct_to_space).split():
                lang_model.index_word(word)
    print(' indexed %d words' % lang_model.n_words)
|
|
|
def index_words_from_textgrid(lang_model, data_path):
    """Index the interval marks of every TextGrid under ``data_path``.

    For each entry of ``data_path`` whose name contains no ``.``, the files
    in its ``text`` sub-directory are parsed with ``textgrid`` and the marks
    of the first tier are indexed.

    Args:
        lang_model: Object exposing ``index_word(word)`` and ``n_words``.
        data_path: Dataset root; a trailing slash is optional.
    """
    import textgrid as tg

    # Punctuation is replaced by spaces but the mark is NOT split or stripped,
    # so e.g. "hi," is indexed as "hi " -- kept as-is so the indexed keys do
    # not change.
    punct_to_space = str.maketrans(",.?!", "    ")

    # NOTE(review): os.listdir order is arbitrary, so the indices assigned to
    # words depend on the directory listing order.
    for loadtype in os.listdir(data_path):
        # Entries with a dot in their name are skipped.
        if "." in loadtype:
            continue
        # os.path.join works whether or not data_path ends with a separator.
        text_dir = os.path.join(data_path, loadtype, "text")
        for textfile in os.listdir(text_dir):
            tgrid = tg.TextGrid.fromFile(os.path.join(text_dir, textfile))
            for interval in tgrid[0]:
                lang_model.index_word(interval.mark.translate(punct_to_space))
    print(' indexed %d words' % lang_model.n_words)
|
|
|
if __name__ == "__main__":
    # Script entry point: build the vocabulary with 300-dimensional vectors
    # from the hard-coded word-vector file and write it next to the dataset.
    vocab_name = "beat_english_15_141"
    dataset_root = "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/"
    vocab_cache = "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl"
    fasttext_bin = "/home/ma-user/work/datasets/cc.en.300.bin"
    build_vocab(vocab_name, dataset_root, vocab_cache, fasttext_bin, 300)
|
|