# EMAGE / dataloaders / build_vocab.py
# Uploaded by H-Liu1997 using huggingface_hub (commit 2d47d90, verified); file size 7.68 kB.
# (Hugging Face file-page links: raw, history, blame)
import numpy as np
import glob
import os
import pickle
import lmdb
import pyarrow
import fasttext
from loguru import logger
from scipy import linalg
class Vocab:
    """Word <-> index vocabulary with optional pretrained embedding weights.

    Instance attributes:
        name: label supplied by the caller.
        trimmed: True once ``trim`` has been applied.
        word2index: maps each indexed word to its integer index.
        index2word: inverse mapping, including the reserved tokens.
        word2count: number of times each word was passed to ``index_word``.
        n_words: next index that ``index_word`` will hand out.
        word_embedding_weights: ``None`` until ``load_word_vectors`` fills it
            with a float32 array of shape ``(n_words, embedding_dim)``.
    """

    PAD_token = 0
    SOS_token = 1
    EOS_token = 2
    UNK_token = 3

    def __init__(self, name, insert_default_tokens=True):
        """Create an empty vocabulary called *name*.

        Args:
            name: label stored on the instance.
            insert_default_tokens: forwarded to ``reset_dictionary``.
        """
        self.name = name
        self.trimmed = False
        self.word_embedding_weights = None
        self.reset_dictionary(insert_default_tokens)

    def reset_dictionary(self, insert_default_tokens=True):
        """Drop every indexed word and re-create the reserved tokens.

        Args:
            insert_default_tokens: when True reserve <PAD>, <SOS>, <EOS> and
                <UNK>; when False reserve only <UNK>.
        """
        self.word2index = {}
        self.word2count = {}
        if insert_default_tokens:
            self.index2word = {self.PAD_token: "<PAD>", self.SOS_token: "<SOS>",
                               self.EOS_token: "<EOS>", self.UNK_token: "<UNK>"}
        else:
            self.index2word = {self.UNK_token: "<UNK>"}
        # Next free index. max()+1 (not len()) so that, when only <UNK> is
        # reserved, new words never get assigned UNK_token's own index.
        # With all four default tokens this is the same value as before (4).
        self.n_words = max(self.index2word) + 1

    def index_word(self, word):
        """Register one occurrence of *word*, assigning an index on first sight."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def add_vocab(self, other_vocab):
        """Index every word known to *other_vocab*.

        Each word is indexed once per call, i.e. the counts stored in
        *other_vocab* are not carried over.
        """
        for word in other_vocab.word2count:
            self.index_word(word)

    def trim(self, min_count):
        """Remove words seen fewer than *min_count* times.

        Runs at most once per instance (guarded by ``self.trimmed``). The
        dictionary is rebuilt with the default reserved tokens, and each kept
        word is re-indexed with a count of 1.
        """
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [word for word, count in self.word2count.items()
                      if count >= min_count]

        total = len(self.word2index)
        # An empty vocabulary used to raise ZeroDivisionError here.
        kept_ratio = len(keep_words) / total if total else 0.0
        print(' word trimming, kept %s / %s = %.4f' % (
            len(keep_words), total, kept_ratio
        ))

        # reinitialize dictionary
        self.reset_dictionary()
        for word in keep_words:
            self.index_word(word)

    def get_word_index(self, word):
        """Return the index of *word*, or ``UNK_token`` if it is not indexed."""
        return self.word2index.get(word, self.UNK_token)

    def load_word_vectors(self, pretrained_path, embedding_dim=300):
        """Fill ``word_embedding_weights`` from a fastText model file.

        Args:
            pretrained_path: path handed to ``fasttext.load_model``.
            embedding_dim: width of the embedding matrix. ``None`` means
                "use the dimension reported by the loaded model".
        """
        print(" loading word vectors from '{}'...".format(pretrained_path))

        word_model = fasttext.load_model(pretrained_path)
        if embedding_dim is None:
            embedding_dim = word_model.get_dimension()

        # initialize embeddings to random values for special words
        init_sd = 1 / np.sqrt(embedding_dim)
        weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
        weights = weights.astype(np.float32)

        # overwrite the rows of all indexed words with the model's vectors
        for word, index in self.word2index.items():
            weights[index] = word_model.get_word_vector(word)
        self.word_embedding_weights = weights

    def __get_embedding_weight(self, pretrained_path, embedding_dim=300):
        """Build an embedding matrix from a text-format embedding file.

        Function modified from http://ronny.rest/blog/post_2017_08_04_glove/

        Each line of *pretrained_path* is expected to be a word followed by
        its vector components, separated by whitespace. The resulting matrix
        is cached as ``<pretrained_path without extension>_cache.pkl``.

        Args:
            pretrained_path: UTF-8 text file with one word vector per line.
            embedding_dim: expected vector length.

        Returns:
            float32 array of shape ``(n_words, embedding_dim)``; rows of
            words missing from the file keep their random initialisation.
        """
        print("Loading word embedding '{}'...".format(pretrained_path))
        # The cache must live NEXT TO the embedding file. Using the embedding
        # path itself would unpickle a text file and later overwrite it.
        cache_path = os.path.splitext(pretrained_path)[0] + '_cache.pkl'
        weights = None

        # use cached file if it exists
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                print(' using cached result from {}'.format(cache_path))
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point this at cache files you created yourself.
                weights = pickle.load(f)
            if weights.shape != (self.n_words, embedding_dim):
                logger.warning(' failed to load word embedding weights. reinitializing...')
                weights = None

        if weights is None:
            # initialize embeddings to random values for special and OOV words
            init_sd = 1 / np.sqrt(embedding_dim)
            weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
            weights = weights.astype(np.float32)

            num_embedded_words = 0
            with open(pretrained_path, encoding="utf-8", mode="r") as text_file:
                for line_raw in text_file:
                    fields = line_raw.split()
                    if not fields:
                        # blank line: nothing to parse (used to raise IndexError)
                        continue
                    index = self.word2index.get(fields[0])
                    if index is None:
                        # word is not in our vocabulary; skip parsing its vector
                        continue
                    try:
                        weights[index] = np.array(fields[1:], dtype=np.float32)
                    except ValueError:
                        # non-numeric component or wrong vector length
                        print(' parsing error at {}...'.format(line_raw[:50]))
                        continue
                    num_embedded_words += 1
            print(' {} / {} word vectors are found in the embedding'.format(
                num_embedded_words, len(self.word2index)))

            with open(cache_path, 'wb') as f:
                pickle.dump(weights, f)

        return weights
def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None):
    """Build (or reload) the vocabulary for a dataset and pickle it.

    When *word_vec_path* is given the vocabulary is always rebuilt from the
    TextGrid files and word vectors are loaded into it. When it is ``None``
    an existing pickle at *cache_path* is reused (with its embedding weights
    dropped on the returned object); if no such pickle exists the vocabulary
    is built from the TextGrid files without embeddings.

    Args:
        name: name given to the new ``Vocab``.
        data_path: dataset root passed to ``index_words_from_textgrid``.
        cache_path: pickle file the freshly built vocabulary is written to.
        word_vec_path: model path for ``Vocab.load_word_vectors``, or None.
        feat_dim: embedding dimension; ``None`` keeps the default of
            ``Vocab.load_word_vectors``.

    Returns:
        The ``Vocab`` instance.
    """
    print(' building a language model...')

    if word_vec_path is None and os.path.exists(cache_path):
        print(' loaded from {}'.format(cache_path))
        with open(cache_path, 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load vocabulary caches from a trusted location.
            lang_model = pickle.load(f)
        lang_model.word_embedding_weights = None
        # The cache was only read; leave the file (and any embedding weights
        # stored in it) untouched.
        return lang_model

    lang_model = Vocab(name)
    print(' indexing words from {}'.format(data_path))
    index_words_from_textgrid(lang_model, data_path)

    if word_vec_path is not None:
        if feat_dim is None:
            # do not forward None: fall back to load_word_vectors' default
            lang_model.load_word_vectors(word_vec_path)
        else:
            lang_model.load_word_vectors(word_vec_path, feat_dim)

    # Create the cache directory up front so the work above is not lost to a
    # FileNotFoundError on the final write.
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(lang_model, f)

    return lang_model
def index_words(lang_model, data_path):
    """Index every whitespace-separated word of a plain-text file.

    Commas, periods, question marks and exclamation marks are turned into
    spaces before splitting, so they act as word separators.

    Args:
        lang_model: object exposing ``index_word(word)`` and ``n_words``.
        data_path: path of the text file to read.
    """
    punctuation_to_space = str.maketrans({",": " ", ".": " ", "?": " ", "!": " "})
    with open(data_path, "r") as text_file:
        for raw_line in text_file:
            for token in raw_line.translate(punctuation_to_space).split():
                lang_model.index_word(token)
    print(' indexed %d words' % lang_model.n_words)
def index_words_from_textgrid(lang_model, data_path):
    """Index the marks found in every file of ``<data_path>/textgrid/``.

    Each file in that folder is opened with ``textgrid.TextGrid.fromFile``
    and every entry of ``tgrid[0]`` (presumably the first tier - confirm
    against the textgrid package) has its ``mark`` indexed. Commas, periods,
    question marks and exclamation marks inside a mark are replaced by
    spaces; the mark is NOT split afterwards, so it is indexed as one entry.

    Args:
        lang_model: object exposing ``index_word``, ``n_words``,
            ``word2index`` and ``word2count``.
        data_path: dataset root containing the ``textgrid`` folder.
    """
    import textgrid as tg
    from tqdm import tqdm

    textgrid_dir = data_path + "/textgrid/"
    punctuation_to_space = str.maketrans({",": " ", ".": " ", "?": " ", "!": " "})

    for file_name in tqdm(os.listdir(textgrid_dir)):
        grid = tg.TextGrid.fromFile(textgrid_dir + file_name)
        for entry in grid[0]:
            lang_model.index_word(entry.mark.translate(punctuation_to_space))

    print(' indexed %d words' % lang_model.n_words)
    # dump the full dictionaries for inspection
    print(lang_model.word2index, lang_model.word2count)
if __name__ == "__main__":
    # Original note: "11195 for all, 5793 for 4 speakers" (presumably the
    # resulting vocabulary sizes for the English data - confirm).
    # English configuration that was used before:
    #   name:      beat_english_15_141
    #   data/cache: /home/ma-user/work/datasets/beat_cache/beat_english_15_141/ (vocab.pkl inside)
    #   vectors:   /home/ma-user/work/datasets/cc.en.300.bin, dim 300
    dataset_name = "beat_chinese_v1.0.0"
    dataset_root = "/data/datasets/beat_chinese_v1.0.0/"
    vocab_cache_file = "/data/datasets/beat_chinese_v1.0.0/weights/vocab.pkl"
    fasttext_model_file = "/home/ma-user/work/cc.zh.300.bin"
    embedding_dim = 300

    build_vocab(dataset_name, dataset_root, vocab_cache_file, fasttext_model_file, embedding_dim)