'''Pretrained word-embedding wrappers (GloVe and BPEmb) for seq2struct.'''
import abc
import functools
import os

import bpemb
import torch
import torchtext

from seq2struct.resources import corenlp
from seq2struct.utils import registry


class Embedder(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def tokenize(self, sentence):
        '''Given a string, return a list of tokens suitable for lookup.'''
        pass

    @abc.abstractmethod
    def untokenize(self, tokens):
        '''Undo tokenize.'''
        pass

    @abc.abstractmethod
    def lookup(self, token):
        '''Given a token, return a vector embedding if token is in vocabulary.

        If token is not in the vocabulary, then return None.'''
        pass

    @abc.abstractmethod
    def contains(self, token):
        '''Return True if token is in the vocabulary.'''
        pass

    @abc.abstractmethod
    def to(self, device):
        '''Transfer the pretrained embeddings to the given device.'''
        pass
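

# A minimal concrete Embedder, sketched here purely to illustrate the
# contract above. The toy vocabulary and random vectors are assumptions
# made for the example; nothing below is part of the seq2struct API.
#
#     class ToyEmbedder(Embedder):
#         def __init__(self, words, dim=4):
#             self.stoi = {w: i for i, w in enumerate(words)}
#             self.dim = dim
#             self.vectors = torch.randn(len(words), dim)
#
#         def tokenize(self, sentence):
#             return sentence.lower().split()
#
#         def untokenize(self, tokens):
#             return ' '.join(tokens)
#
#         def lookup(self, token):
#             i = self.stoi.get(token)
#             return None if i is None else self.vectors[i]
#
#         def contains(self, token):
#             return token in self.stoi
#
#         def to(self, device):
#             self.vectors = self.vectors.to(device)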


@registry.register('word_emb', 'glove')
class GloVe(Embedder):
    def __init__(self, kind, lemmatize=False):
        # Cache downloaded vectors under $CACHE_DIR if set, else the
        # current working directory.
        cache = os.path.join(os.environ.get('CACHE_DIR', os.getcwd()), '.vector_cache')
        self.glove = torchtext.vocab.GloVe(name=kind, cache=cache)
        self.dim = self.glove.dim
        self.vectors = self.glove.vectors
        self.lemmatize = lemmatize
        self.corenlp_annotators = ['tokenize', 'ssplit']
        if lemmatize:
            self.corenlp_annotators.append('lemma')

    @functools.lru_cache(maxsize=1024)
    def tokenize(self, text):
        ann = corenlp.annotate(text, self.corenlp_annotators)
        if self.lemmatize:
            return [tok.lemma.lower() for sent in ann.sentence for tok in sent.token]
        else:
            return [tok.word.lower() for sent in ann.sentence for tok in sent.token]

    @functools.lru_cache(maxsize=1024)
    def tokenize_for_copying(self, text):
        # Return both the normalized tokens and the original surface forms,
        # so a copy mechanism can reproduce the input text verbatim.
        ann = corenlp.annotate(text, self.corenlp_annotators)
        text_for_copying = [tok.originalText.lower() for sent in ann.sentence for tok in sent.token]
        if self.lemmatize:
            text = [tok.lemma.lower() for sent in ann.sentence for tok in sent.token]
        else:
            text = [tok.word.lower() for sent in ann.sentence for tok in sent.token]
        return text, text_for_copying

    def untokenize(self, tokens):
        return ' '.join(tokens)

    def lookup(self, token):
        i = self.glove.stoi.get(token)
        if i is None:
            return None
        return self.vectors[i]

    def contains(self, token):
        return token in self.glove.stoi

    def to(self, device):
        self.vectors = self.vectors.to(device)
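

# Usage sketch for GloVe. Assumptions: torchtext can download the chosen
# vector set into the cache, and a CoreNLP server is reachable through
# seq2struct.resources.corenlp; the sentence and device are illustrative.
#
#     emb = GloVe(kind='42B')
#     tokens = emb.tokenize('How many singers do we have?')
#     vecs = [emb.lookup(t) for t in tokens]   # None for out-of-vocabulary tokens
#     emb.to(torch.device('cuda'))             # move the embedding matrix to GPU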


@registry.register('word_emb', 'bpemb')
class BPEmb(Embedder):
    def __init__(self, dim, vocab_size, lang='en'):
        self.bpemb = bpemb.BPEmb(lang=lang, dim=dim, vs=vocab_size)
        self.dim = dim
        self.vectors = torch.from_numpy(self.bpemb.vectors)

    def tokenize(self, text):
        return self.bpemb.encode(text)

    def untokenize(self, tokens):
        return self.bpemb.decode(tokens)

    def lookup(self, token):
        # SentencePiece maps unknown pieces to the unk id, so treat that
        # as out-of-vocabulary.
        i = self.bpemb.spm.PieceToId(token)
        if i == self.bpemb.spm.unk_id():
            return None
        return self.vectors[i]

    def contains(self, token):
        return self.lookup(token) is not None

    def to(self, device):
        self.vectors = self.vectors.to(device)
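

if __name__ == '__main__':
    # Smoke-test sketch for the BPEmb embedder. Assumption: the bpemb
    # package downloads the small pretrained English model on first use,
    # so this needs network access the first time it runs.
    emb = BPEmb(dim=100, vocab_size=10000)
    tokens = emb.tokenize('How many singers do we have?')
    print(tokens)                              # subword pieces such as '▁how'
    print(emb.untokenize(tokens))              # round-trips to the input text
    vec = emb.lookup(tokens[0])
    print(None if vec is None else vec.shape)  # torch.Size([100])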