"""Module that contain shard utils for dynamic data."""
import os
from onmt.utils.logging import logger
from onmt.constants import CorpusName
from onmt.transforms import TransformPipe
from onmt.inputters.dataset_base import _dynamic_dict
from torchtext.data import Dataset as TorchtextDataset, \
    Example as TorchtextExample

from collections import Counter, defaultdict
from contextlib import contextmanager

import multiprocessing as mp


@contextmanager
def exfile_open(filename, *args, **kwargs):
    """Extended file opener enables open(filename=None).

    This context manager enables open(filename=None) as well as regular file.
    filename None will produce endlessly None for each iterate,
    while filename with valid path will produce lines as usual.

    Args:
        filename (str|None): a valid file path or None;
        *args: args relate to open file using codecs;
        **kwargs: kwargs relate to open file using codecs.

    Yields:
        `None` repeatly if filename==None,
        else yield from file specified in `filename`.
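
    Example (illustrative)::

        >>> with exfile_open(None) as lines:
        ...     got = [next(lines) for _ in range(2)]
        >>> got
        [None, None]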
    """
    if filename is None:
        from itertools import repeat
        _file = repeat(None)
    else:
        import codecs
        _file = codecs.open(filename, *args, **kwargs)
    try:
        yield _file
    finally:
        # Close the real file even if the caller raises inside the block.
        if filename is not None and _file:
            _file.close()


class DatasetAdapter(object):
    """Adapte a buckets of tuples into examples of a torchtext Dataset."""

    valid_field_name = (
        'src', 'tgt', 'indices', 'src_map', 'src_ex_vocab', 'alignment',
        'align')

    def __init__(self, fields, is_train):
        self.fields_dict = self._valid_fields(fields)
        self.is_train = is_train

    @classmethod
    def _valid_fields(cls, fields):
        """Return valid fields in dict format."""
        return {
            f_k: f_v for f_k, f_v in fields.items()
            if f_k in cls.valid_field_name
        }

    @staticmethod
    def _process(item, is_train):
        """Return valid transformed example from `item`."""
        example, transform, cid = item
        # this is a hack: appears quicker to apply it here
        # than in the ParallelCorpusIterator
        maybe_example = transform.apply(
            example, is_train=is_train, corpus_name=cid)
        if maybe_example is None:
            return None

        maybe_example['src'] = {"src": ' '.join(maybe_example['src'])}

        # Make features part of src as in TextMultiField
        # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}}
        if 'src_feats' in maybe_example:
            for feat_name, feat_value in maybe_example['src_feats'].items():
                maybe_example['src'][feat_name] = ' '.join(feat_value)
            del maybe_example["src_feats"]

        maybe_example['tgt'] = {"tgt": ' '.join(maybe_example['tgt'])}
        if 'align' in maybe_example:
            maybe_example['align'] = ' '.join(maybe_example['align'])

        return maybe_example

    def _maybe_add_dynamic_dict(self, example, fields):
        """maybe update `example` with dynamic_dict related fields."""
        if 'src_map' in fields and 'alignment' in fields:
            example = _dynamic_dict(
                example,
                fields['src'].base_field,
                fields['tgt'].base_field)
        return example

    def _to_examples(self, bucket, is_train=False):
        examples = []
        for item in bucket:
            maybe_example = self._process(item, is_train=is_train)
            if maybe_example is not None:
                example = self._maybe_add_dynamic_dict(
                    maybe_example, self.fields_dict)
                ex_fields = {k: [(k, v)] for k, v in self.fields_dict.items()
                             if k in example}
                ex = TorchtextExample.fromdict(example, ex_fields)
                examples.append(ex)
        return examples

    def __call__(self, bucket):
        examples = self._to_examples(bucket, is_train=self.is_train)
        dataset = TorchtextDataset(examples, self.fields_dict)
        return dataset
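
# A minimal usage sketch for DatasetAdapter (names here are hypothetical;
# `fields` would come from onmt.inputters and `bucket` is a list of
# (example, transform, corpus_id) tuples yielded by ParallelCorpusIterator):
#
#     adapter = DatasetAdapter(fields, is_train=True)
#     dataset = adapter(bucket)  # -> torchtext Dataset ready for batching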


class ParallelCorpus(object):
    """A parallel corpus file pair that can be loaded to iterate."""

    def __init__(self, name, src, tgt, align=None, src_feats=None):
        """Initialize src & tgt side file path."""
        self.id = name
        self.src = src
        self.tgt = tgt
        self.align = align
        self.src_feats = src_feats

    def load(self, offset=0, stride=1):
        """
        Load files and iterate line by line.
        `offset` and `stride` allow iterating over only every
        `stride`-th example, starting from `offset`.
        """
        if self.src_feats:
            features_names = []
            features_files = []
            for feat_name, feat_path in self.src_feats.items():
                features_names.append(feat_name)
                features_files.append(open(feat_path, mode='rb'))
        else:
            features_files = []
        with exfile_open(self.src, mode='rb') as fs,\
                exfile_open(self.tgt, mode='rb') as ft,\
                exfile_open(self.align, mode='rb') as fa:
            for i, (sline, tline, align, *features) in \
                    enumerate(zip(fs, ft, fa, *features_files)):
                if (i % stride) == offset:
                    sline = sline.decode('utf-8')
                    tline = tline.decode('utf-8')
                    # 'src_original' and 'tgt_original' store the
                    # original line before tokenization. These
                    # fields are used later on in the feature
                    # transforms.
                    example = {
                        'src': sline,
                        'tgt': tline,
                        'src_original': sline,
                        'tgt_original': tline
                    }
                    if align is not None:
                        example['align'] = align.decode('utf-8')
                    if features:
                        example["src_feats"] = dict()
                        for j, feat in enumerate(features):
                            example["src_feats"][features_names[j]] = \
                                feat.decode("utf-8")
                    yield example
        for f in features_files:
            f.close()

    def __str__(self):
        cls_name = type(self).__name__
        return '{}({}, {}, align={}, src_feats={})'.format(
            cls_name, self.src, self.tgt, self.align, self.src_feats)
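
# A minimal usage sketch for ParallelCorpus (file paths are hypothetical):
#
#     corpus = ParallelCorpus("demo", "data/train.src", "data/train.tgt")
#     for ex in corpus.load(offset=0, stride=2):
#         print(ex['src'], ex['tgt'])  # raw lines, decoded but untokenized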


def get_corpora(opts, is_train=False):
    corpora_dict = {}
    if is_train:
        for corpus_id, corpus_dict in opts.data.items():
            if corpus_id != CorpusName.VALID:
                corpora_dict[corpus_id] = ParallelCorpus(
                    corpus_id,
                    corpus_dict["path_src"],
                    corpus_dict["path_tgt"],
                    corpus_dict["path_align"],
                    corpus_dict["src_feats"])
    else:
        if CorpusName.VALID in opts.data.keys():
            corpora_dict[CorpusName.VALID] = ParallelCorpus(
                CorpusName.VALID,
                opts.data[CorpusName.VALID]["path_src"],
                opts.data[CorpusName.VALID]["path_tgt"],
                opts.data[CorpusName.VALID]["path_align"],
                opts.data[CorpusName.VALID]["src_feats"])
        else:
            return None
    return corpora_dict


class ParallelCorpusIterator(object):
    """An iterator dedicate for ParallelCorpus.

    Args:
        corpus (ParallelCorpus): corpus to iterate;
        transform (TransformPipe): transforms to be applied to corpus;
        skip_empty_level (str): how to react when an empty line is encountered;
        stride (int): iterate corpus with this line stride;
        offset (int): iterate corpus with this line offset.
    """

    def __init__(self, corpus, transform,
                 skip_empty_level='warning', stride=1, offset=0):
        self.cid = corpus.id
        self.corpus = corpus
        self.transform = transform
        if skip_empty_level not in ['silent', 'warning', 'error']:
            raise ValueError(
                f"Invalid argument skip_empty_level={skip_empty_level}")
        self.skip_empty_level = skip_empty_level
        self.stride = stride
        self.offset = offset

    def _tokenize(self, stream):
        for example in stream:
            example['src'] = example['src'].strip('\n').split()
            example['tgt'] = example['tgt'].strip('\n').split()
            example['src_original'] = \
                example['src_original'].strip("\n").split()
            example['tgt_original'] = \
                example['tgt_original'].strip("\n").split()
            if 'align' in example:
                example['align'] = example['align'].strip('\n').split()
            if 'src_feats' in example:
                for k in example['src_feats'].keys():
                    example['src_feats'][k] = \
                        example['src_feats'][k].strip('\n').split()
            yield example

    def _transform(self, stream):
        for example in stream:
            # NOTE: transform.apply is deferred to DatasetAdapter._process
            # (in this module); here we only bundle the example with its
            # transform and corpus id.
            item = (example, self.transform, self.cid)
            if item is not None:
                yield item
        report_msg = self.transform.stats()
        if report_msg != '':
            logger.info(
                "* Transform statistics for {}({:.2f}%):\n{}\n".format(
                    self.cid, 100/self.stride, report_msg
                )
            )

    def _add_index(self, stream):
        for i, item in enumerate(stream):
            example = item[0]
            line_number = i * self.stride + self.offset
            example['indices'] = line_number
            if (len(example['src']) == 0 or len(example['tgt']) == 0 or
                    ('align' in example and len(example['align']) == 0)):
                # empty example: skip
                empty_msg = f"Empty line exists in {self.cid}#{line_number}."
                if self.skip_empty_level == 'error':
                    raise IOError(empty_msg)
                elif self.skip_empty_level == 'warning':
                    logger.warning(empty_msg)
                continue
            yield item

    def __iter__(self):
        corpus_stream = self.corpus.load(
            stride=self.stride, offset=self.offset
        )
        tokenized_corpus = self._tokenize(corpus_stream)
        transformed_corpus = self._transform(tokenized_corpus)
        indexed_corpus = self._add_index(transformed_corpus)
        yield from indexed_corpus


def build_corpora_iters(corpora, transforms, corpora_info,
                        skip_empty_level='warning', stride=1, offset=0):
    """Return `ParallelCorpusIterator` for all corpora defined in opts."""
    corpora_iters = dict()
    for c_id, corpus in corpora.items():
        transform_names = corpora_info[c_id].get('transforms', [])
        corpus_transform = [
            transforms[name] for name in transform_names if name in transforms
        ]
        transform_pipe = TransformPipe.build_from(corpus_transform)
        logger.info(f"{c_id}'s transforms: {str(transform_pipe)}")
        corpus_iter = ParallelCorpusIterator(
            corpus, transform_pipe,
            skip_empty_level=skip_empty_level, stride=stride, offset=offset)
        corpora_iters[c_id] = corpus_iter
    return corpora_iters
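
# A minimal usage sketch tying get_corpora and build_corpora_iters together
# (hypothetical setup; `opts` is the parsed dynamic-data config and
# `transforms` a dict mapping names to built Transform objects):
#
#     corpora = get_corpora(opts, is_train=True)
#     iters = build_corpora_iters(corpora, transforms, opts.data,
#                                 skip_empty_level=opts.skip_empty_level)
#     for c_name, c_iter in iters.items():
#         for item in c_iter:  # (example, transform, corpus_id) tuples
#             ...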


def write_files_from_queues(sample_path, queues):
    """
    Standalone process that reads data from
    queues in order and writes to sample files.
    """
    os.makedirs(sample_path, exist_ok=True)
    for c_name in queues.keys():
        dest_base = os.path.join(
            sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
        with open(dest_base + ".src", 'w', encoding="utf-8") as f_src,\
                open(dest_base + ".tgt", 'w', encoding="utf-8") as f_tgt:
            while True:
                _next = False
                for q in queues[c_name]:
                    item = q.get()
                    if item == "blank":
                        continue
                    if item == "break":
                        _next = True
                        break
                    _, src_line, tgt_line = item
                    f_src.write(src_line + '\n')
                    f_tgt.write(tgt_line + '\n')
                if _next:
                    break


# Just for debugging purposes
# It appends features to subwords when dumping to file
def append_features_to_example(example, features):
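    """Pair each subword of `example` with its feature token position-wise.

    E.g. append_features_to_example("a b", "X Y") returns "aX bY".
    """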
    ex_toks = example.split(' ')
    feat_toks = features.split(' ')
    toks = [f"{subword}{feat}" for subword, feat in
            zip(ex_toks, feat_toks)]
    return " ".join(toks)


def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    """Build vocab on (strided) subpart of the data."""
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
    sub_counter_src_feats = defaultdict(Counter)
    datasets_iterables = build_corpora_iters(
        corpora, transforms, opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride, offset=offset)
    for c_name, c_iter in datasets_iterables.items():
        for i, item in enumerate(c_iter):
            maybe_example = DatasetAdapter._process(item, is_train=True)
            if maybe_example is None:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("blank")
                continue
            src_line, tgt_line = (maybe_example['src']['src'],
                                  maybe_example['tgt']['tgt'])
            src_line_pretty = src_line
            for feat_name, feat_line in maybe_example["src"].items():
                if feat_name not in ["src", "src_original"]:
                    sub_counter_src_feats[feat_name].update(
                        feat_line.split(' '))
                    if opts.dump_samples:
                        src_line_pretty = append_features_to_example(
                            src_line_pretty, feat_line)
            sub_counter_src.update(src_line.split(' '))
            sub_counter_tgt.update(tgt_line.split(' '))
            if opts.dump_samples:
                build_sub_vocab.queues[c_name][offset].put(
                    (i, src_line_pretty, tgt_line))
            if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("break")
                break
        if opts.dump_samples:
            build_sub_vocab.queues[c_name][offset].put("break")
    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
    """Add the queues as attribute of the pooled function."""
    build_sub_vocab.queues = queues


def build_vocab(opts, transforms, n_sample=3):
    """Build vocabulary from data."""

    if n_sample == -1:
        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
    elif n_sample > 0:
        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
    else:
        raise ValueError(f"n_sample should > 0 or == -1, get {n_sample}.")

    if opts.dump_samples:
        logger.info("The samples on which the vocab is built will be "
                    "dumped to disk. It may slow down the process.")
    corpora = get_corpora(opts, is_train=True)
    counter_src = Counter()
    counter_tgt = Counter()
    counter_src_feats = defaultdict(Counter)
    from functools import partial
    queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
                       for i in range(opts.num_threads)]
              for c_name in corpora.keys()}
    sample_path = os.path.join(
        os.path.dirname(opts.save_data), CorpusName.SAMPLE)
    if opts.dump_samples:
        write_process = mp.Process(
            target=write_files_from_queues,
            args=(sample_path, queues),
            daemon=True)
        write_process.start()
    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
        func = partial(
            build_sub_vocab, corpora, transforms,
            opts, n_sample, opts.num_threads)
        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
                func, range(0, opts.num_threads)):
            counter_src.update(sub_counter_src)
            counter_tgt.update(sub_counter_tgt)
            # Merge per-feature counters rather than overwriting them.
            for feat_name, sub_counter in sub_counter_src_feats.items():
                counter_src_feats[feat_name].update(sub_counter)
    if opts.dump_samples:
        write_process.join()
    return counter_src, counter_tgt, counter_src_feats
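
# A minimal call sketch for build_vocab (hypothetical setup; `opts` is the
# parsed dynamic-data config and `transforms` a dict of built Transform
# objects):
#
#     src_counter, tgt_counter, src_feats_counters = build_vocab(
#         opts, transforms, n_sample=10000)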


def save_transformed_sample(opts, transforms, n_sample=3):
    """Save transformed data sample as specified in opts."""

    if n_sample == -1:
        logger.info(f"n_sample={n_sample}: Save full transformed corpus.")
    elif n_sample == 0:
        logger.info(f"n_sample={n_sample}: no sample will be saved.")
        return
    elif n_sample > 0:
        logger.info(f"Save {n_sample} transformed example/corpus.")
    else:
        raise ValueError(f"n_sample should >= -1, get {n_sample}.")

    corpora = get_corpora(opts, is_train=True)
    datasets_iterables = build_corpora_iters(
        corpora, transforms, opts.data,
        skip_empty_level=opts.skip_empty_level)
    sample_path = os.path.join(
        os.path.dirname(opts.save_data), CorpusName.SAMPLE)
    os.makedirs(sample_path, exist_ok=True)
    for c_name, c_iter in datasets_iterables.items():
        dest_base = os.path.join(
            sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
        with open(dest_base + ".src", 'w', encoding="utf-8") as f_src,\
                open(dest_base + ".tgt", 'w', encoding="utf-8") as f_tgt:
            for i, item in enumerate(c_iter):
                maybe_example = DatasetAdapter._process(item, is_train=True)
                if maybe_example is None:
                    continue
                src_line, tgt_line = (maybe_example['src']['src'],
                                      maybe_example['tgt']['tgt'])
                f_src.write(src_line + '\n')
                f_tgt.write(tgt_line + '\n')
                if n_sample > 0 and i >= n_sample:
                    break