File size: 3,137 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Module for build dynamic fields."""
from collections import Counter, defaultdict
import torch
from onmt.utils.logging import logger
from onmt.utils.misc import check_path
from onmt.inputters.inputter import get_fields, _load_vocab, \
    _build_fields_vocab


def _get_dynamic_fields(opts):
    """Create the field objects for the 'text' data type from ``opts``.

    Args:
        opts: options object; reads ``src_feats_vocab``, ``copy_attn``,
            ``src_seq_length_trunc``, ``tgt_seq_length_trunc``,
            ``data_task`` and, when present, ``lambda_align``.

    Returns:
        The value returned by ``get_fields`` for these settings.
    """
    # Alignment is switched on only when the option exists on ``opts``
    # and is strictly positive; a missing option counts as disabled.
    use_alignment = getattr(opts, 'lambda_align', 0.0) > 0.0

    src_feats = opts.src_feats_vocab
    field_options = dict(
        dynamic_dict=opts.copy_attn,
        src_truncate=opts.src_seq_length_trunc,
        tgt_truncate=opts.tgt_seq_length_trunc,
        with_align=use_alignment,
        data_task=opts.data_task,
    )

    # NOTE: target-side features are not supported yet, so the third
    # positional argument is always None.
    return get_fields('text', src_feats, None, **field_options)


def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
    """Create the dynamic fields and attach vocabularies read from files.

    Args:
        opts: options object; reads ``src_vocab``, ``src_feats_vocab``,
            ``tgt_vocab``, ``share_vocab``, ``vocab_size_multiple`` and
            the ``{src,tgt}_vocab_size`` /
            ``{src,tgt}_words_min_frequency`` settings.
        src_specials: forwarded unchanged to ``_build_fields_vocab``.
        tgt_specials: forwarded unchanged to ``_build_fields_vocab``.

    Returns:
        The value returned by ``_build_fields_vocab``.

    Raises:
        ValueError: if ``opts.tgt_vocab`` is not given and
            ``opts.share_vocab`` is not set.
    """
    bare_fields = _get_dynamic_fields(opts)

    # One Counter per field name, passed to every ``_load_vocab`` call
    # (presumably filled in place -- confirm in onmt.inputters.inputter).
    token_counts = defaultdict(Counter)
    logger.info("Loading vocab from text file...")

    _, _ = _load_vocab(
        opts.src_vocab, 'src', token_counts,
        min_freq=opts.src_words_min_frequency)

    # Source-side feature vocabularies are loaded with no frequency cut.
    for feature, vocab_file in (opts.src_feats_vocab or {}).items():
        _, _ = _load_vocab(vocab_file, feature, token_counts, min_freq=0)

    if not opts.tgt_vocab and not opts.share_vocab:
        raise ValueError("-tgt_vocab should be specified if not share_vocab.")

    if opts.tgt_vocab:
        _, _ = _load_vocab(
            opts.tgt_vocab, 'tgt', token_counts,
            min_freq=opts.tgt_words_min_frequency)
    else:
        logger.info("Sharing src vocab to tgt...")
        # Both keys now refer to the very same Counter object.
        token_counts['tgt'] = token_counts['src']

    logger.info("Building fields with vocab in counters...")
    return _build_fields_vocab(
        bare_fields, token_counts, 'text', opts.share_vocab,
        opts.vocab_size_multiple,
        opts.src_vocab_size, opts.src_words_min_frequency,
        opts.tgt_vocab_size, opts.tgt_words_min_frequency,
        src_specials=src_specials, tgt_specials=tgt_specials)


def get_vocabs(fields):
    """Return ``{'src': ..., 'tgt': ...}`` with each side's base vocab.

    Args:
        fields: mapping with at least the keys ``'src'`` and ``'tgt'``;
            each value must expose ``.base_field.vocab``.

    Returns:
        A new dict mapping ``'src'`` and ``'tgt'`` to the corresponding
        ``base_field.vocab`` objects.
    """
    return {side: fields[side].base_field.vocab for side in ('src', 'tgt')}


def save_fields(fields, save_data, overwrite=True):
    """Serialize ``fields`` with ``torch.save`` to ``<save_data>.vocab.pt``.

    Args:
        fields: the object to dump.
        save_data: path prefix; ``.vocab.pt`` is appended to it.
        overwrite: passed to ``check_path`` as ``exist_ok``.
    """
    destination = "{}.vocab.pt".format(save_data)
    # NOTE(review): ``check_path`` presumably warns or raises when the
    # destination already exists, depending on ``exist_ok`` -- confirm
    # in onmt.utils.misc.
    check_path(destination, exist_ok=overwrite, log=logger.warning)
    logger.info(f"Saving fields to {destination}...")
    torch.save(fields, destination)


def load_fields(save_data, checkpoint=None):
    """Return the fields stored in ``checkpoint`` or dumped at ``save_data``.

    Args:
        save_data: path prefix; ``<save_data>.vocab.pt`` is read when no
            checkpoint is supplied.
        checkpoint: optional mapping; when not None its ``'vocab'`` entry
            is returned and ``save_data`` is ignored.

    Returns:
        The loaded fields object.
    """
    if checkpoint is None:
        dump_path = "{}.vocab.pt".format(save_data)
        logger.info(f"Loading fields from {dump_path}...")
        # NOTE: ``torch.load`` unpickles the file, so only point this at
        # dumps from a trusted source.
        return torch.load(dump_path)

    logger.info("Loading fields from checkpoint...")
    return checkpoint['vocab']