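# Tests for fairseq.data.Dictionary: vocabulary construction, finalization,
# serialization, and special-symbol handling.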
import io
import os
import string
import tempfile
import unittest

import torch

from fairseq import tokenizer
from fairseq.data import Dictionary


class TestDictionary(unittest.TestCase):
    def test_finalize(self):
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
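        # fairseq's Dictionary reserves the low indices for special symbols
        # (<s>=0, <pad>=1, </s>=2, <unk>=3), so the first new token ("A")
        # gets index 4, and encode_line appends eos (index 2) to every line.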
        ref_ids1 = list(
            map(
                torch.IntTensor,
                [
                    [4, 5, 6, 7, 2],
                    [5, 6, 7, 2],
                    [6, 7, 2],
                    [7, 2],
                ],
            )
        )
        ref_ids2 = list(
            map(
                torch.IntTensor,
                [
                    [7, 6, 5, 4, 2],
                    [6, 5, 4, 2],
                    [5, 4, 2],
                    [4, 2],
                ],
            )
        )
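        # ref_ids2 is the expected mapping after finalize(), which re-sorts
        # symbols by descending count: "D" (4 occurrences) moves to index 4
        # while "A" (1 occurrence) moves to index 7.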

        # build the dictionary from the training text
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            ids = []
            for line in txt:
                ids.append(dictionary.encode_line(line, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        # check the mapping before finalization
        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # finalization re-sorts the vocabulary by frequency
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload; the finalized mapping must round-trip
        with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)
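
    # Note: the on-disk format produced by Dictionary.save (and consumed by
    # add_from_file) is one "<symbol> <count>" pair per line; the io.StringIO
    # payloads in the tests below emulate that format.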

    def test_overwrite(self):
        # entries marked "#fairseq:overwrite" may redefine the default
        # special symbols (e.g. CamemBERT overwrites <unk>, <s> and </s>)
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index("<pad>"), 1)
        self.assertEqual(d.index("foo"), 3)  # OOV lookups return unk (index 3)
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        # without the "#fairseq:overwrite" flag, redefining a default
        # special symbol raises an error
        dict_file = io.StringIO(
            "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        # a bare space is a legal symbol, e.g. for character-level models
        dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n")
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)

    def test_add_file_to_dict(self):
        # write a corpus with known per-letter counts, then check that the
        # multi-worker dictionary builder recovers exactly those counts
        counts = {}
        num_lines = 100
        per_line = 10
        with tempfile.TemporaryDirectory("test_sampling") as data_dir:
            filename = os.path.join(data_dir, "dummy.txt")
            with open(filename, "w", encoding="utf-8") as data:
                for c in string.ascii_letters:
                    line = f"{c} " * per_line
                    for _ in range(num_lines):
                        data.write(f"{line}\n")
                    counts[c] = per_line * num_lines
                    per_line += 5

            d = Dictionary()
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, 10
            )
            d.finalize(threshold=0, nwords=-1, padding_factor=8)

            for c in string.ascii_letters:
                count = d.get_count(d.index(c))
                self.assertEqual(
                    counts[c], count, f"{c} count is {count} but should be {counts[c]}"
                )


if __name__ == "__main__":
    unittest.main()