Spaces:
Build error
Build error
__author__ = 'Taneem Jan, taneemishere.github.io' | |
import sys | |
import numpy as np | |
# Special tokens that Vocabulary.__init__ registers first, in this order,
# so a fresh vocabulary assigns them ids 0, 1 and 2.
START_TOKEN = "<START>"
END_TOKEN = "<END>"
PLACEHOLDER = " "

# Delimiter written between a token and its encoded vector in words.vocab.
SEPARATOR = "->"
class Vocabulary:
    """Token vocabulary with one-hot ("binary") encodings and file persistence.

    Tokens receive consecutive integer ids in insertion order. START_TOKEN,
    END_TOKEN and PLACEHOLDER are registered first by ``__init__``, so a fresh
    vocabulary assigns them ids 0, 1 and 2.

    Attributes:
        binary_vocabulary: token -> one-hot float vector of length ``size``;
            filled by ``create_binary_representation`` or ``retrieve``.
        vocabulary: token -> integer id.
        token_lookup: integer id -> token (inverse of ``vocabulary``).
        size: number of distinct registered tokens.
    """

    def __init__(self):
        self.binary_vocabulary = {}
        self.vocabulary = {}
        self.token_lookup = {}
        self.size = 0

        self.append(START_TOKEN)
        self.append(END_TOKEN)
        self.append(PLACEHOLDER)

    def append(self, token):
        """Register *token* under the next free id; known tokens are ignored."""
        if token not in self.vocabulary:
            self.vocabulary[token] = self.size
            self.token_lookup[self.size] = token
            self.size += 1

    def create_binary_representation(self):
        """Build a one-hot vector for every registered token.

        Each vector is a float array of length ``size`` holding a single 1 at
        the token's id. Existing entries for the same tokens are overwritten.
        """
        # dict.items() works on both Python 2 (list) and Python 3 (view).
        for key, value in self.vocabulary.items():
            binary = np.zeros(self.size)
            binary[value] = 1
            self.binary_vocabulary[key] = binary

    def get_serialized_binary_representation(self):
        """Return the vocabulary as text, one ``token->vector`` entry per line.

        Each entry is ``<token><SEPARATOR><v0>,<v1>,...`` plus a newline
        (with default numpy print options, e.g. ``<START>->1.,0.,0.``).
        The one-hot vectors are (re)built first if they are missing or do not
        cover every registered token.

        Returns:
            The serialised vocabulary as a single string.
        """
        # The original rebuilt only when the cache was empty; tokens appended
        # after a build would then be missing or carry stale vector lengths.
        if len(self.binary_vocabulary) != self.size:
            self.create_binary_representation()

        lines = []
        for key, value in self.binary_vocabulary.items():
            # A finite line width can wrap the vector onto several lines, and
            # the default threshold (1000 elements) elides the middle with
            # "..."; retrieve() can parse neither, so disable both.
            array_as_string = np.array2string(value, separator=',',
                                              max_line_width=sys.maxsize,
                                              threshold=sys.maxsize)
            # Strip the surrounding "[" and "]".
            lines.append("{}{}{}\n".format(key, SEPARATOR, array_as_string[1:-1]))
        return "".join(lines)

    def save(self, path):
        """Write the serialised vocabulary to ``<path>/words.vocab``."""
        output_file_name = "{}/words.vocab".format(path)
        # Serialise before opening, so a failure does not leave a truncated
        # file behind; ``with`` guarantees the handle is closed.
        serialized = self.get_serialized_binary_representation()
        with open(output_file_name, 'w') as output_file:
            output_file.write(serialized)

    def retrieve(self, path):
        """Load a vocabulary written by ``save`` from ``<path>/words.vocab``.

        Entries are merged into the current state (nothing is cleared first),
        and ``size`` is then set to the number of known tokens. A token may
        span several lines: lines without SEPARATOR are buffered and prefixed
        to the next line that contains one.
        """
        buffer = ""
        with open("{}/words.vocab".format(path), 'r') as input_file:
            for line in input_file:
                try:
                    # The encoded vector never contains SEPARATOR, so its
                    # last occurrence is the key/value boundary, even for
                    # tokens that themselves contain SEPARATOR.
                    separator_position = len(buffer) + line.rindex(SEPARATOR)
                except ValueError:
                    # No separator: this line belongs to a multi-line token.
                    buffer += line
                    continue

                buffer += line
                key = buffer[:separator_position]
                value = np.fromstring(buffer[separator_position + len(SEPARATOR):], sep=',')
                # Plain int, matching the ids that append() assigns.
                index = int(np.where(value == 1)[0][0])

                self.binary_vocabulary[key] = value
                self.vocabulary[key] = index
                self.token_lookup[index] = key
                buffer = ""
        self.size = len(self.vocabulary)