File size: 6,120 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import six
import argparse
import torch
from onmt.utils.logging import init_logger, logger
def get_vocabs(dict_path):
    """Load the source and target vocabularies stored at ``dict_path``.

    The object returned by ``torch.load`` is indexed with ``'src'`` and
    ``'tgt'``. For each entry the vocabulary is taken from
    ``.base_field.vocab`` when that attribute chain exists, otherwise from
    ``.vocab``.

    Args:
        dict_path: path of the serialized fields object.

    Returns:
        A ``(enc_vocab, dec_vocab)`` tuple (source first, target second).
    """
    loaded = torch.load(dict_path)

    def _pick_vocab(side):
        # Some entries wrap the real field in ``base_field``; fall back to
        # the entry's own ``vocab`` when that wrapper is absent.
        try:
            return loaded[side].base_field.vocab
        except AttributeError:
            return loaded[side].vocab

    enc_vocab, dec_vocab = [_pick_vocab(side) for side in ('src', 'tgt')]

    logger.info("From: %s" % dict_path)
    for label, voc in (("source", enc_vocab), ("target", dec_vocab)):
        logger.info("\t* " + label + " vocab: %d words" % len(voc))

    return enc_vocab, dec_vocab
def read_embeddings(file_enc, skip_lines=0, filter_set=None):
    """Read word vectors from a space-separated, UTF-8 text file.

    Each data line is expected to look like ``word v1 v2 ... vN``.

    Args:
        file_enc: path of the embedding file.
        skip_lines: number of leading lines to ignore unconditionally.
        filter_set: optional container of words; when given, only words
            for which ``word in filter_set`` holds are kept.

    Returns:
        A tuple ``(embs, total_vectors_in_file)`` where ``embs`` maps each
        kept word to a list of floats and ``total_vectors_in_file`` counts
        every vector line seen (before filtering).

    Raises:
        ValueError: if a value on a data line cannot be parsed as a float.
    """
    embs = dict()
    total_vectors_in_file = 0
    # Binary mode + explicit decode keeps the behaviour identical on
    # Python 2 and Python 3.
    with open(file_enc, 'rb') as f:
        for i, line in enumerate(f):
            if i < skip_lines:
                continue

            text = line.decode('utf8').strip()
            if not text:
                # Blank / whitespace-only line: it carries no vector, so it
                # must neither be counted nor stored under the '' key.
                continue

            l_split = text.split(' ')
            if len(l_split) == 2:
                # Two fields only: treated as a header line, not a vector
                # (presumably the "<count> <dim>" line of word2vec files).
                continue

            total_vectors_in_file += 1
            word = l_split[0]
            if filter_set is not None and word not in filter_set:
                continue
            embs[word] = [float(em) for em in l_split[1:]]
    return embs, total_vectors_in_file
def convert_to_torch_tensor(word_to_float_list_dict, vocab):
    """Build a ``(len(vocab), dim)`` tensor of embeddings for ``vocab``.

    Rows are indexed through ``vocab.stoi``. Rows of vocabulary words that
    have no entry in ``word_to_float_list_dict`` stay zero.

    Args:
        word_to_float_list_dict: mapping ``word -> list of floats``; the
            embedding dimension is taken from its first value.
        vocab: object supporting ``len()`` and exposing a ``stoi`` mapping
            from word to row index.

    Returns:
        A ``torch`` tensor with one row per vocabulary entry.

    Raises:
        ValueError: if ``word_to_float_list_dict`` is empty, because the
            embedding dimension cannot be determined.
    """
    if not word_to_float_list_dict:
        raise ValueError("No embeddings to convert: the word-to-vector "
                         "mapping is empty.")

    dim = len(next(iter(word_to_float_list_dict.values())))
    tensor = torch.zeros((len(vocab), dim))
    stoi = vocab.stoi
    for word, values in word_to_float_list_dict.items():
        # The dict may hold words from another vocabulary (e.g. when one
        # file feeds both sides). Indexing ``stoi`` with such a word would
        # either raise KeyError or, for a defaulting mapping, overwrite the
        # row of the default index - so skip them.
        if word not in stoi:
            continue
        tensor[stoi[word]] = torch.Tensor(values)
    return tensor
def calc_vocab_load_stats(vocab, loaded_embed_dict):
    """Report how much of ``vocab`` is covered by the loaded embeddings.

    Args:
        vocab: object supporting ``len()`` and exposing a ``stoi`` mapping
            keyed by word.
        loaded_embed_dict: mapping keyed by the words that have a vector.

    Returns:
        ``(matching_count, missing_count, percent_matching)`` where the
        first is the number of ``vocab.stoi`` keys present in
        ``loaded_embed_dict``, the second is ``len(vocab)`` minus that, and
        the third is the matching share of ``len(vocab)`` in percent.
    """
    vocab_size = len(vocab)
    # Keys of a mapping are unique, so counting hits equals the size of
    # the key intersection.
    found = sum(1 for word in vocab.stoi if word in loaded_embed_dict)
    not_found = vocab_size - found
    coverage = found / vocab_size * 100
    return found, not_found, coverage
def main():
    """Command-line entry point.

    Parses the options, loads the source/target vocabularies from
    ``-dict_file`` (via ``get_vocabs``), reads the embeddings either from
    one shared file (``-emb_file_both``) or from separate encoder/decoder
    files, logs how many vocabulary words were matched, and saves the two
    resulting tensors to ``<output_file>.enc.pt`` and
    ``<output_file>.dec.pt``.

    Raises:
        ValueError: if ``-emb_file_both`` is combined with
            ``-emb_file_enc`` / ``-emb_file_dec``, or if ``-emb_file_both``
            is absent and one of the two per-side files is missing.
    """
    parser = argparse.ArgumentParser(description='embeddings_to_torch.py')
    parser.add_argument('-emb_file_both', required=False,
                        help="loads Embeddings for both source and target "
                             "from this file.")
    parser.add_argument('-emb_file_enc', required=False,
                        help="source Embeddings from this file")
    parser.add_argument('-emb_file_dec', required=False,
                        help="target Embeddings from this file")
    parser.add_argument('-output_file', required=True,
                        help="Output file for the prepared data")
    parser.add_argument('-dict_file', required=True,
                        help="Dictionary file")
    # NOTE: -verbose is accepted but not read anywhere in this function.
    parser.add_argument('-verbose', action="store_true", default=False)
    parser.add_argument('-skip_lines', type=int, default=0,
                        help="Skip first lines of the embedding file")
    parser.add_argument('-type', choices=["GloVe", "word2vec"],
                        default="GloVe")
    opt = parser.parse_args()

    enc_vocab, dec_vocab = get_vocabs(opt.dict_file)

    # Read in embeddings
    # For word2vec exactly one leading line is skipped, regardless of the
    # value given for -skip_lines.
    skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines
    if opt.emb_file_both is not None:
        if opt.emb_file_enc is not None:
            raise ValueError("If -emb_file_both is passed in, you should "
                             "not set -emb_file_enc.")
        if opt.emb_file_dec is not None:
            raise ValueError("If -emb_file_both is passed in, you should "
                             "not set -emb_file_dec.")
        # One pass over the shared file keeps every word that appears in
        # either vocabulary; both sides then share the same dict.
        set_of_src_and_tgt_vocab = \
            set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys())
        logger.info("Reading encoder and decoder embeddings from {}".format(
            opt.emb_file_both))
        src_vectors, total_vec_count = \
            read_embeddings(opt.emb_file_both, skip_lines,
                            set_of_src_and_tgt_vocab)
        tgt_vectors = src_vectors
        logger.info("\tFound {} total vectors in file".format(total_vec_count))
    else:
        if opt.emb_file_enc is None:
            raise ValueError("-emb_file_enc not provided. Please specify "
                             "the file with encoder embeddings, or pass in "
                             "-emb_file_both")
        if opt.emb_file_dec is None:
            raise ValueError("-emb_file_dec not provided. Please specify "
                             "the file with decoder embeddings, or pass in "
                             "-emb_file_both")
        logger.info("Reading encoder embeddings from {}".format(
            opt.emb_file_enc))
        src_vectors, total_vec_count = read_embeddings(
            opt.emb_file_enc, skip_lines,
            filter_set=enc_vocab.stoi
        )
        logger.info("\tFound {} total vectors in file.".format(
            total_vec_count))
        logger.info("Reading decoder embeddings from {}".format(
            opt.emb_file_dec))
        tgt_vectors, total_vec_count = read_embeddings(
            opt.emb_file_dec, skip_lines,
            filter_set=dec_vocab.stoi
        )
        logger.info("\tFound {} total vectors in file".format(total_vec_count))

    logger.info("After filtering to vectors in vocab:")
    logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
                % calc_vocab_load_stats(enc_vocab, src_vectors))
    logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
                % calc_vocab_load_stats(dec_vocab, tgt_vectors))

    # Write to file
    enc_output_file = opt.output_file + ".enc.pt"
    dec_output_file = opt.output_file + ".dec.pt"
    logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
                % (enc_output_file, dec_output_file))
    torch.save(
        convert_to_torch_tensor(src_vectors, enc_vocab),
        enc_output_file
    )
    torch.save(
        convert_to_torch_tensor(tgt_vectors, dec_vocab),
        dec_output_file
    )
    logger.info("\nDone.")
# Script entry point: runs only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    # Set up logging first (passing 'embeddings_to_torch.log' to
    # init_logger) so that the logger calls made inside main() are
    # configured before they run.
    init_logger('embeddings_to_torch.log')
    main()
|