|
import gradio as gr |
|
import io |
|
import numpy as np |
|
|
|
|
|
class TrieNode:
    """One node of a character trie.

    Holds the outgoing edges (one per character) and a flag marking
    whether the path from the root to this node spells a complete token.
    """

    def __init__(self):
        # Maps a single character to the child TrieNode reached by it.
        self.children = {}
        # True when a full vocabulary token ends exactly at this node.
        self.is_end_of_token = False
|
class Trie:
    """Character trie over a fixed vocabulary, supporting greedy
    longest-prefix lookup for tokenization."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, token):
        """Add *token* to the trie, one character per level."""
        node = self.root
        for ch in token:
            # setdefault creates the child on first sight of this character.
            node = node.children.setdefault(ch, TrieNode())
        node.is_end_of_token = True

    def search_longest_prefix(self, text, start):
        """Find the longest vocabulary token beginning at text[start].

        Returns the index of that token's final character, or None when
        no token in the trie is a prefix of text[start:].
        """
        node = self.root
        best_end = None
        pos = start
        text_len = len(text)
        while pos < text_len:
            child = node.children.get(text[pos])
            if child is None:
                break
            node = child
            if node.is_end_of_token:
                # Remember the deepest complete token seen so far.
                best_end = pos
            pos += 1
        return best_end
|
|
|
|
|
def load_vectors(fname):
    """Load fastText-style word vectors from *fname*.

    Each line is expected to be "token v1 v2 ... vN", space-separated.

    Returns a tuple of:
      - dict mapping token -> numpy float array (its embedding)
      - list of all tokens sorted longest-first, for greedy longest-match
        tokenization

    NOTE(review): official fastText .vec files start with a "count dim"
    header line, which this loader does not skip — confirm the file format
    used here, otherwise the header is parsed as a bogus token.
    """
    data = {}
    # Context manager guarantees the handle is closed; the original
    # `del fin` merely dropped the reference without an explicit close.
    with io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        for line in fin:
            tokens = line.rstrip().split(' ')
            data[tokens[0]] = np.array(list(map(float, tokens[1:])))
    return data, sorted(data.keys(), key=len, reverse=True)
|
# Load the pretrained embeddings once at import time: `vectors` maps
# token -> numpy vector, `sorted_vector` lists all tokens longest-first.
vectors, sorted_vector = load_vectors('wiki-news-300d-1M.vec')
|
|
|
|
|
def tokenize(text):
    """Split *text* into vocabulary tokens by greedy longest-match.

    At each position the longest known token starting there is consumed;
    characters that begin no known token are silently skipped. Returns the
    list of matched tokens, in order.
    """
    # Build the vocabulary trie once and reuse it across calls — the
    # original rebuilt it (one insert per vocabulary token, ~1M tokens)
    # on every invocation.
    trie = getattr(tokenize, "_trie", None)
    if trie is None:
        trie = Trie()
        for token in sorted_vector:
            trie.insert(token)
        tokenize._trie = trie

    result = []
    start = 0
    while start < len(text):
        match_end = trie.search_longest_prefix(text, start)
        if match_end is not None:
            result.append(text[start:match_end + 1])
            start = match_end + 1
        else:
            # No vocabulary token starts here; skip a single character.
            start += 1

    return result
|
|
|
|
|
def onInput(paragraph, progress=gr.Progress()):
    """Average the embedding vectors of every recognized token in *paragraph*.

    Returns a 300-element list of floats (the mean token embedding), or a
    zero vector when no vocabulary token is found in the input. Progress is
    reported through the Gradio `progress` callback.
    """
    progress(0, "Tokenizing...")
    tokens = tokenize(paragraph)

    progress(0.1, "Initializing merged vector...")
    if not tokens:
        # Nothing recognized: return the zero embedding rather than dividing by 0.
        return np.zeros(300).tolist()

    merged_vector = np.zeros(300)

    totalTokens = len(tokens)
    for ind, token in enumerate(tokens):
        completion = 0.7 * ((ind + 1) / totalTokens)
        # Use the enumerate index directly: the original's tokens.index(token)
        # was an O(n) scan per iteration and reported the index of the FIRST
        # occurrence for duplicate tokens.
        progress(0.1 + completion, f"Merging {token}, Token #{ind + 1}/{totalTokens}")

        # Every token produced by tokenize() is in the vocabulary, so the
        # lookup cannot raise KeyError.
        merged_vector += vectors[token]

    progress(0.9, "Normalizing...")
    merged_vector /= totalTokens

    progress(1, "Converting to list...")
    return merged_vector.tolist()
|
|
|
# Build the Gradio UI: one textbox in, the embedding rendered as text out.
demo = gr.Interface(fn=onInput, inputs="text", outputs="text")

# Launch only when executed as a script, so importing this module
# (e.g. for testing or reuse) does not start a web server.
if __name__ == "__main__":
    demo.launch()