import itertools
import os
import unittest
from copy import deepcopy

from torchtext.data import Field

from onmt.inputters.text_dataset import TextMultiField, TextDataReader
from onmt.tests.utils_for_tests import product_dict
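# Unit tests for TextMultiField (a base torchtext Field bundled with zero or
# more named feature Fields) and for TextDataReader. The added comments below
# summarize what each group of cases asserts.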
class TestTextMultiField(unittest.TestCase):
    INIT_CASES = list(product_dict(
        base_name=["base_field", "zbase_field"],
        base_field=[Field],
        feats_fields=[
            [],
            [("a", Field)],
            [("r", Field), ("b", Field)]]))
    PARAMS = list(product_dict(
        include_lengths=[False, True]))

    @classmethod
    def initialize_case(cls, init_case, params):
        # initialize fields at the top of each unit test to prevent
        # any undesired stateful effects
        case = deepcopy(init_case)
        case["base_field"] = case["base_field"](
            include_lengths=params["include_lengths"])
        for i, (n, f_cls) in enumerate(case["feats_fields"]):
            case["feats_fields"][i] = (n, f_cls(sequential=True))
        return case
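    # process() should turn a batch of (words, feature, feature, ...) examples
    # into a single tensor of shape (max_len, batch_size, n_fields), plus a
    # per-example lengths tensor of shape (batch_size,) when include_lengths
    # is set.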
    def test_process_shape(self):
        dummy_input_bs_1 = [[
            ["this", "is", "for", "the", "unittest"],
            ["NOUN", "VERB", "PREP", "ART", "NOUN"],
            ["", "", "", "", "MODULE"]]]
        dummy_input_bs_5 = [
            [["this", "is", "for", "the", "unittest"],
             ["NOUN", "VERB", "PREP", "ART", "NOUN"],
             ["", "", "", "", "MODULE"]],
            [["batch", "2"],
             ["NOUN", "NUM"],
             ["", ""]],
            [["batch", "3", "is", "the", "longest", "batch"],
             ["NOUN", "NUM", "VERB", "ART", "ADJ", "NOUN"],
             ["", "", "", "", "", ""]],
            [["fourth", "batch"],
             ["ORD", "NOUN"],
             ["", ""]],
            [["and", "another", "one"],
             ["CONJ", "?", "NUM"],
             ["", "", ""]]]
        for bs, max_len, dummy_input in [
                (1, 5, dummy_input_bs_1), (5, 6, dummy_input_bs_5)]:
            for init_case, params in itertools.product(
                    self.INIT_CASES, self.PARAMS):
                init_case = self.initialize_case(init_case, params)
                mf = TextMultiField(**init_case)
                fields = [init_case["base_field"]] \
                    + [f for _, f in init_case["feats_fields"]]
                nfields = len(fields)
                for i, f in enumerate(fields):
                    all_sents = [b[i] for b in dummy_input]
                    f.build_vocab(all_sents)
                inp_only_desired_fields = [b[:nfields] for b in dummy_input]
                data = mf.process(inp_only_desired_fields)
                if params["include_lengths"]:
                    data, lengths = data
                    self.assertEqual(lengths.shape, (bs,))
                expected_shape = (max_len, bs, nfields)
                self.assertEqual(data.shape, expected_shape)
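    # preprocess() should return one preprocessed column per field: the base
    # field first, followed by one entry per feature field.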
    def test_preprocess_shape(self):
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            init_case = self.initialize_case(init_case, params)
            mf = TextMultiField(**init_case)
            sample_str = {
                "base_field": "dummy input here .",
                "a": "A A B D",
                "r": "C C C C",
                "b": "D F E D",
                "zbase_field": "another dummy input ."
            }
            proc = mf.preprocess(sample_str)
            self.assertEqual(len(proc), len(init_case["feats_fields"]) + 1)

    def test_base_field(self):
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            init_case = self.initialize_case(init_case, params)
            mf = TextMultiField(**init_case)
            self.assertIs(mf.base_field, init_case["base_field"])

    def test_correct_n_fields(self):
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            init_case = self.initialize_case(init_case, params)
            mf = TextMultiField(**init_case)
            self.assertEqual(len(mf.fields),
                             len(init_case["feats_fields"]) + 1)

    def test_fields_order_correct(self):
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            init_case = self.initialize_case(init_case, params)
            mf = TextMultiField(**init_case)
            fnames = [name for name, _ in init_case["feats_fields"]]
            correct_order = [init_case["base_name"]] + list(sorted(fnames))
            self.assertEqual([name for name, _ in mf.fields], correct_order)

    def test_getitem_0_returns_correct_field(self):
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            init_case = self.initialize_case(init_case, params)
            mf = TextMultiField(**init_case)
            self.assertEqual(mf[0][0], init_case["base_name"])
            self.assertIs(mf[0][1], init_case["base_field"])

    def test_getitem_nonzero_returns_correct_field(self):
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            init_case = self.initialize_case(init_case, params)
            mf = TextMultiField(**init_case)
            fnames = [name for name, _ in init_case["feats_fields"]]
            if len(fnames) > 0:
                ordered_names = list(sorted(fnames))
                name2field = dict(init_case["feats_fields"])
                for i, name in enumerate(ordered_names, 1):
                    expected_field = name2field[name]
                    self.assertIs(mf[i][1], expected_field)

    def test_getitem_has_correct_number_of_indexes(self):
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            init_case = self.initialize_case(init_case, params)
            mf = TextMultiField(**init_case)
            nfields = len(init_case["feats_fields"]) + 1
            with self.assertRaises(IndexError):
                mf[nfields]
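# TextDataReader.read yields one example dict per input line, keyed by the
# side name ("src" here). The tests below cover in-memory byte strings and
# reading from a file on disk.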
class TestTextDataReader(unittest.TestCase):
    def test_read(self):
        strings = [
            "hello world".encode("utf-8"),
            "this's a string with punctuation .".encode("utf-8"),
            "ThIs Is A sTrInG wItH oDD CapitALIZAtion".encode("utf-8")
        ]
        rdr = TextDataReader()
        for i, ex in enumerate(rdr.read(strings, "src")):
            self.assertEqual(ex["src"], {"src": strings[i].decode("utf-8")})


class TestTextDataReaderFromFS(unittest.TestCase):
    # this test touches the file system, so it could be considered an
    # integration test
    STRINGS = [
        "hello world\n".encode("utf-8"),
        "this's a string with punctuation . \n".encode("utf-8"),
        "ThIs Is A sTrInG wItH oDD CapitALIZAtion\n".encode("utf-8")
    ]
    FILE_NAME = "test_strings.txt"

    @classmethod
    def setUpClass(cls):
        # write utf-8 bytes
        with open(cls.FILE_NAME, "wb") as f:
            for str_ in cls.STRINGS:
                f.write(str_)

    @classmethod
    def tearDownClass(cls):
        os.remove(cls.FILE_NAME)

    def test_read(self):
        rdr = TextDataReader()
        for i, ex in enumerate(rdr.read(self.FILE_NAME, "src")):
            self.assertEqual(
                ex["src"], {"src": self.STRINGS[i].decode("utf-8")})
class TestTextDataReaderWithFeatures(unittest.TestCase):
    def test_read(self):
        strings = [
            "hello world".encode("utf-8"),
            "this's a string with punctuation .".encode("utf-8"),
            "ThIs Is A sTrInG wItH oDD CapitALIZAtion".encode("utf-8")
        ]
        features = {
            "feat_0": [
                "A A".encode("utf-8"),
                "A A B B C".encode("utf-8"),
                "A A D D E E".encode("utf-8")
            ]
        }
        rdr = TextDataReader()
        for i, ex in enumerate(rdr.read(strings, "src", features)):
            self.assertEqual(
                ex["src"],
                {"src": strings[i].decode("utf-8"),
                 "feat_0": features["feat_0"][i].decode("utf-8")})