"""Convert pre-trained word embeddings in GloVe/word2vec text format
into the encoder/decoder embedding tensors used by OpenNMT-py."""
from __future__ import division

import argparse

import six
import torch

from onmt.utils.logging import init_logger, logger


def get_vocabs(dict_path):
    """Load source and target vocabularies from a saved fields file."""
    fields = torch.load(dict_path)

    vocs = []
    for side in ['src', 'tgt']:
        try:
            # Newer OpenNMT-py versions wrap the vocab in a multi-field.
            vocab = fields[side].base_field.vocab
        except AttributeError:
            # Older versions expose the vocab directly on the field.
            vocab = fields[side].vocab
        vocs.append(vocab)
    enc_vocab, dec_vocab = vocs

    logger.info("From: %s" % dict_path)
    logger.info("\t* source vocab: %d words" % len(enc_vocab))
    logger.info("\t* target vocab: %d words" % len(dec_vocab))

    return enc_vocab, dec_vocab


def read_embeddings(file_enc, skip_lines=0, filter_set=None):
    """Read a text embedding file into a word -> list-of-floats dict,
    keeping only words in ``filter_set`` when one is given."""
    embs = dict()
    total_vectors_in_file = 0
    with open(file_enc, 'rb') as f:
        for i, line in enumerate(f):
            if i < skip_lines:
                continue
            if not line.strip():
                # Skip blank lines; iteration stops at EOF on its own.
                continue

            l_split = line.decode('utf8').strip().split(' ')
            if len(l_split) == 2:
                # A two-token line is a word2vec-style "<count> <dim>"
                # header, not a vector.
                continue
            total_vectors_in_file += 1
            if filter_set is not None and l_split[0] not in filter_set:
                continue
            embs[l_split[0]] = [float(em) for em in l_split[1:]]
    return embs, total_vectors_in_file
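
# For reference, GloVe-style files hold one whitespace-separated vector
# per line, e.g.:
#
#   the 0.418 0.24968 -0.41242 ...
#
# The word2vec text format additionally starts with a "<count> <dim>"
# header line, handled by skip_lines and the len(l_split) == 2 check.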


def convert_to_torch_tensor(word_to_float_list_dict, vocab):
    """Pack the word -> vector dict into a (len(vocab), dim) tensor
    indexed by the vocabulary's stoi mapping."""
    dim = len(six.next(six.itervalues(word_to_float_list_dict)))

    tensor = torch.zeros((len(vocab), dim))
    for word, values in word_to_float_list_dict.items():
        tensor[vocab.stoi[word]] = torch.Tensor(values)
    return tensor
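
# Note that vocab words without a pre-trained vector keep an all-zero
# row; an alternative (not done here) would be random initialization.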


def calc_vocab_load_stats(vocab, loaded_embed_dict):
    """Report how many vocabulary entries received a pre-trained vector."""
    matching_count = len(
        set(vocab.stoi.keys()) & set(loaded_embed_dict.keys()))
    missing_count = len(vocab) - matching_count
    percent_matching = matching_count / len(vocab) * 100
    return matching_count, missing_count, percent_matching
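
# Illustrative (hypothetical) numbers: a 50,004-word vocab with 48,500
# matches returns (48500, 1504, 96.99...), i.e. about 97% coverage.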


def main():
    parser = argparse.ArgumentParser(description='embeddings_to_torch.py')
    parser.add_argument('-emb_file_both', required=False,
                        help="loads embeddings for both source and target "
                             "from this file")
    parser.add_argument('-emb_file_enc', required=False,
                        help="source embeddings from this file")
    parser.add_argument('-emb_file_dec', required=False,
                        help="target embeddings from this file")
    parser.add_argument('-output_file', required=True,
                        help="output file for the prepared data")
    parser.add_argument('-dict_file', required=True,
                        help="dictionary file")
    parser.add_argument('-verbose', action="store_true", default=False)
    parser.add_argument('-skip_lines', type=int, default=0,
                        help="skip first lines of the embedding file")
    parser.add_argument('-type', choices=["GloVe", "word2vec"],
                        default="GloVe",
                        help="format of the embeddings file")
    opt = parser.parse_args()

    enc_vocab, dec_vocab = get_vocabs(opt.dict_file)

    # word2vec text files carry a one-line "<count> <dim>" header.
    skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines
    if opt.emb_file_both is not None:
        if opt.emb_file_enc is not None:
            raise ValueError("If -emb_file_both is passed in, you should "
                             "not set -emb_file_enc.")
        if opt.emb_file_dec is not None:
            raise ValueError("If -emb_file_both is passed in, you should "
                             "not set -emb_file_dec.")
        set_of_src_and_tgt_vocab = \
            set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys())
        logger.info("Reading encoder and decoder embeddings from {}".format(
            opt.emb_file_both))
        src_vectors, total_vec_count = \
            read_embeddings(opt.emb_file_both, skip_lines,
                            set_of_src_and_tgt_vocab)
        # Both sides share a single embedding table in this case.
        tgt_vectors = src_vectors
        logger.info("\tFound {} total vectors in file".format(
            total_vec_count))
    else:
        if opt.emb_file_enc is None:
            raise ValueError("-emb_file_enc not provided. Please specify "
                             "the file with encoder embeddings, or pass in "
                             "-emb_file_both.")
        if opt.emb_file_dec is None:
            raise ValueError("-emb_file_dec not provided. Please specify "
                             "the file with decoder embeddings, or pass in "
                             "-emb_file_both.")
        logger.info("Reading encoder embeddings from {}".format(
            opt.emb_file_enc))
        src_vectors, total_vec_count = read_embeddings(
            opt.emb_file_enc, skip_lines,
            filter_set=enc_vocab.stoi
        )
        logger.info("\tFound {} total vectors in file".format(
            total_vec_count))
        logger.info("Reading decoder embeddings from {}".format(
            opt.emb_file_dec))
        tgt_vectors, total_vec_count = read_embeddings(
            opt.emb_file_dec, skip_lines,
            filter_set=dec_vocab.stoi
        )
        logger.info("\tFound {} total vectors in file".format(
            total_vec_count))
    logger.info("After filtering to vectors in vocab:")
    logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
                % calc_vocab_load_stats(enc_vocab, src_vectors))
    logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
                % calc_vocab_load_stats(dec_vocab, tgt_vectors))

    enc_output_file = opt.output_file + ".enc.pt"
    dec_output_file = opt.output_file + ".dec.pt"
    logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
                % (enc_output_file, dec_output_file))
    torch.save(
        convert_to_torch_tensor(src_vectors, enc_vocab),
        enc_output_file
    )
    torch.save(
        convert_to_torch_tensor(tgt_vectors, dec_vocab),
        dec_output_file
    )
    logger.info("\nDone.")


if __name__ == "__main__":
    init_logger('embeddings_to_torch.log')
    main()
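
# Example invocation (hypothetical file names), converting GloVe vectors
# for a preprocessed demo dataset:
#
#   python embeddings_to_torch.py \
#       -emb_file_both glove.6B.100d.txt \
#       -dict_file data/demo.vocab.pt \
#       -output_file data/embeddings
#
# This writes data/embeddings.enc.pt and data/embeddings.dec.pt, which
# can then be handed to OpenNMT-py training (e.g. via -pre_word_vecs_enc
# and -pre_word_vecs_dec).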