# NMT-LaVi/utils/data.py
import re, os
import nltk
from nltk.corpus import wordnet
import dill as pickle
import pandas as pd
from torchtext import data
from laonlp import tokenize

def multiple_replace(replacements, text):
    # Build a single regular expression that matches any of the dictionary keys
    regex = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
    # For each match, substitute the corresponding value from the dictionary
    return regex.sub(lambda mo: replacements[mo.group()], text)
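
# Illustrative usage of multiple_replace (not from the original module):
#   multiple_replace({"&amp;": "&", "&quot;": '"'}, "Tom &amp; Jerry")  ->  'Tom & Jerry'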

def get_synonym(word, SRC):
    # Return the vocab index of the first WordNet synonym of `word` that is in the
    # SRC vocabulary; return 0 (the <unk> index) if no known synonym is found.
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            idx = SRC.vocab.stoi[lemma.name()]
            if idx != 0:
                return idx
    return 0
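
# Note: get_synonym needs the NLTK WordNet data; on a fresh environment it is
# typically fetched once with nltk.download('wordnet').
# Illustrative call: get_synonym('automobile', SRC) returns the SRC vocab index of a
# WordNet synonym such as 'car' if one is in the vocabulary, otherwise 0 (<unk>).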

class Tokenizer:
    def __init__(self, lang=None):
        if lang is not None:
            # Imported lazily so the rest of the module works without spaCy installed
            import spacy
            self.nlp = spacy.load(lang)
            self.tokenizer_fn = self.nlp.tokenizer
        else:
            # Default: simple whitespace splitting
            self.tokenizer_fn = lambda line: line.strip().split()
# def tokenize(self, sentence):
# sentence = re.sub(
# r"[\*\"β€œβ€\n\\…\+\-\/\=\(\)β€˜β€’:\[\]\|’\!;]", " ", str(sentence))
# sentence = re.sub(r"[ ]+", " ", sentence)
# sentence = re.sub(r"\!+", "!", sentence)
# sentence = re.sub(r"\,+", ",", sentence)
# sentence = re.sub(r"\?+", "?", sentence)
# sentence = sentence.lower()
# return [tok.text for tok in self.tokenizer_fn(sentence) if tok.text != " "]
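
# Illustrative Tokenizer usage (the spaCy model name below is an assumption, not a project file):
#   tok = Tokenizer()                       # whitespace splitting
#   tok.tokenizer_fn("xin chào thế giới")   # -> ['xin', 'chào', 'thế', 'giới']
#   tok = Tokenizer('en_core_web_sm')       # spaCy tokenizer, if that model is installed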

def read_data(src_file, trg_file):
    # One sentence per line; source and target files are aligned line-by-line
    with open(src_file, encoding='utf-8') as f:
        src_data = f.read().strip().split('\n')
    with open(trg_file, encoding='utf-8') as f:
        trg_data = f.read().strip().split('\n')
    return src_data, trg_data

def read_file(file_dir):
    with open(file_dir, 'r', encoding='utf-8') as f:
        return f.read().strip().split('\n')

def write_file(file_dir, content):
    with open(file_dir, 'w', encoding='utf-8') as f:
        f.write(content)
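
# Illustrative round trip with the two file helpers above (path is a placeholder):
#   write_file('sample.txt', 'line 1\nline 2')
#   read_file('sample.txt')   # -> ['line 1', 'line 2']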
def create_fields(src_lang, trg_lang):
#print("loading spacy tokenizers...")
#
# t_src = tokenize(src_lang)
# t_trg = tokenize(trg_lang)
# t_src_tokenizer = t_trg_tokenizer = lambda x: x.strip().split()
    # Target side: simple whitespace tokenization, with <sos>/<eos> markers
    target_tokenizer = lambda x: x.strip().split()
    TRG = data.Field(lower=True, tokenize=target_tokenizer, init_token='<sos>', eos_token='<eos>')
    # Source side: Lao word tokenizer from laonlp
    SRC = data.Field(lower=True, tokenize=tokenize.word_tokenize)
    return SRC, TRG
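
# Illustrative usage; note that src_lang/trg_lang are currently unused because the
# tokenizer setup that referenced them is commented out above:
#   SRC, TRG = create_fields('lo', 'vi')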

def create_dataset(src_data, trg_data, max_strlen, batchsize, device, SRC, TRG, istrain=True):
    print("creating dataset and iterator... ")
    raw_data = {'src': [line for line in src_data], 'trg': [line for line in trg_data]}
    df = pd.DataFrame(raw_data, columns=["src", "trg"])
    # Keep only sentence pairs shorter than max_strlen tokens on both sides
    mask = (df['src'].str.count(' ') < max_strlen) & (df['trg'].str.count(' ') < max_strlen)
    df = df.loc[mask]
    # Round-trip through a temporary CSV so torchtext's TabularDataset can load it;
    # skip_header avoids reading the "src,trg" header row as a training example
    df.to_csv("translate_transformer_temp.csv", index=False)
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv',
                                fields=data_fields, skip_header=True)
    train_iter = MyIterator(train, batch_size=batchsize, device=device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=istrain, shuffle=True)
    os.remove('translate_transformer_temp.csv')
    if istrain:
        SRC.build_vocab(train)
        TRG.build_vocab(train)
    return train_iter
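
# Minimal sketch of how the helpers above fit together (file names and hyper-parameter
# values are assumptions, not values taken from this repo):
#   src_data, trg_data = read_data('train.lo', 'train.vi')
#   SRC, TRG = create_fields('lo', 'vi')
#   train_iter = create_dataset(src_data, trg_data, max_strlen=160, batchsize=1500,
#                               device='cpu', SRC=SRC, TRG=TRG, istrain=True)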

# Iterator that pools chunks of roughly 100 batches' worth of examples, sorts each
# chunk by length, builds token-budgeted batches from it and shuffles those batches,
# so similar-length sequences are batched together and padding is minimized.
class MyIterator(data.Iterator):
def create_batches(self):
if self.train:
def pool(d, random_shuffler):
for p in data.batch(d, self.batch_size * 100):
p_batch = data.batch(
sorted(p, key=self.sort_key),
self.batch_size, self.batch_size_fn)
for b in random_shuffler(list(p_batch)):
yield b
self.batches = pool(self.data(), self.random_shuffler)
else:
self.batches = []
for b in data.batch(self.data(), self.batch_size,
self.batch_size_fn):
self.batches.append(sorted(b, key=self.sort_key))

# Running maxima used by batch_size_fn while a batch is being assembled
max_src_in_batch, max_tgt_in_batch = 0, 0

def batch_size_fn(new, count, sofar):
"Keep augmenting batch and calculate total number of tokens + padding."
global max_src_in_batch, max_tgt_in_batch
if count == 1:
max_src_in_batch = 0
max_tgt_in_batch = 0
max_src_in_batch = max(max_src_in_batch, len(new.src))
max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
src_elements = count * max_src_in_batch
tgt_elements = count * max_tgt_in_batch
return max(src_elements, tgt_elements)
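
# Worked example for batch_size_fn (illustrative): with count=3 examples so far, the
# longest source being 10 tokens and the longest target 12 tokens (14 after adding
# <sos>/<eos>), it returns max(3*10, 3*14) = 42, an upper bound on padded tokens.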
def generate_language_token(lang: str):
return '<{}>'.format(lang.strip())
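
# Illustrative usage: generate_language_token('vi') -> '<vi>'
#
# Minimal smoke test, a sketch only: it runs when this file is executed directly,
# assumes laonlp and the legacy torchtext API are installed, and uses placeholder
# sentences rather than data from this repo.
if __name__ == "__main__":
    src_data = ["ສະບາຍດີ ໂລກ", "ຂອບໃຈ ຫຼາຍໆ"]          # placeholder Lao sentences
    trg_data = ["xin chào thế giới", "cảm ơn nhiều"]     # placeholder Vietnamese sentences
    SRC, TRG = create_fields('lo', 'vi')
    train_iter = create_dataset(src_data, trg_data, max_strlen=50, batchsize=32,
                                device=None, SRC=SRC, TRG=TRG, istrain=True)
    print(generate_language_token('vi'), len(SRC.vocab), len(TRG.vocab))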