File size: 2,583 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-

# Copyright 2018 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""TTS Interface related modules."""

from espnet.asr.asr_utils import torch_load


try:
    import chainer
except ImportError:
    # chainer could not be imported: expose the name, but with no class behind it.
    Reporter = None
else:

    class Reporter(chainer.Chain):
        """Chain that hands observations over to the chainer reporter."""

        def report(self, dicts):
            """Report each dict in ``dicts``, passing this instance along with it.

            Args:
                dicts: Iterable of observations; every element is forwarded
                    unchanged to ``chainer.reporter.report``.

            """
            for observation in dicts:
                chainer.reporter.report(observation, self)


class TTSInterface(object):
    """TTS Interface for ESPnet model implementation.

    Subclasses are expected to override `forward`, `inference` and
    `calculate_all_attentions`; the base versions only raise
    `NotImplementedError`.
    """

    @staticmethod
    def add_arguments(parser):
        """Add model specific arguments to parser.

        The base implementation adds nothing and returns `parser` unchanged.

        Args:
            parser: Parser object to extend.

        Returns:
            The same `parser` object that was passed in.

        """
        return parser

    def __init__(self):
        """Initialize TTS module.

        `Reporter` is set to None at module level when chainer cannot be
        imported. In that case `self.reporter` is set to None instead of
        failing with "'NoneType' object is not callable".
        """
        self.reporter = Reporter() if Reporter is not None else None

    def forward(self, *args, **kwargs):
        """Calculate TTS forward propagation.

        Returns:
            Tensor: Loss value.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        raise NotImplementedError("forward method is not implemented")

    def inference(self, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Returns:
            Tensor: The sequence of generated features (L, odim).
            Tensor: The sequence of stop probabilities (L,).
            Tensor: The sequence of attention weights (L, T).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        raise NotImplementedError("inference method is not implemented")

    def calculate_all_attentions(self, *args, **kwargs):
        """Calculate TTS attention weights.

        Returns:
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    def load_pretrained_model(self, model_path):
        """Load pretrained model parameters into this instance.

        Args:
            model_path: Location of the pretrained model; forwarded as-is to
                `torch_load` together with `self` (presumably a file path --
                confirm against `torch_load`).

        """
        torch_load(model_path, self)

    @property
    def attention_plot_class(self):
        """Return the class used to plot attention weights.

        Returns:
            The `PlotAttentionReport` class from `espnet.asr.asr_utils`.

        """
        # Imported here rather than at module level, as in the original code.
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        The keys should match what `chainer.reporter` reports.
        If you add the key `loss`,
        the reporter will report `main/loss` and `validation/main/loss` values.
        Also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list[str]:  Base keys to plot during training.

        """
        return ["loss"]