File size: 2,583 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# -*- coding: utf-8 -*-
# Copyright 2018 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""TTS Interface related modules."""
from espnet.asr.asr_utils import torch_load
try:
    import chainer
except ImportError:
    # chainer is optional: when it cannot be imported, ``Reporter`` is left as
    # ``None`` so that this module can still be imported.
    Reporter = None
else:

    class Reporter(chainer.Chain):
        """Chainer chain used as the observer for reported values."""

        def report(self, dicts):
            """Forward each observation dict to ``chainer.reporter.report``.

            Args:
                dicts (Iterable[dict]): Observations, reported one by one with
                    this chain as the observer.
            """
            send = chainer.reporter.report
            for observation in dicts:
                send(observation, self)
class TTSInterface(object):
    """TTS Interface for ESPnet model implementation.

    The base implementations of ``forward``, ``inference`` and
    ``calculate_all_attentions`` only raise ``NotImplementedError``;
    concrete models are expected to override them.
    """

    @staticmethod
    def add_arguments(parser):
        """Add model specific arguments to parser.

        The base implementation adds nothing.

        Args:
            parser: Argument parser to extend.

        Returns:
            The same ``parser`` object, unchanged.
        """
        return parser

    def __init__(self):
        """Initialize TTS module.

        Creates ``self.reporter`` from the chainer-based ``Reporter`` class.

        Raises:
            ImportError: If chainer could not be imported, in which case
                ``Reporter`` is unavailable.
        """
        # ``Reporter`` is bound to None at import time when chainer is
        # missing; calling it would fail with an opaque TypeError, so fail
        # early with an actionable message instead.
        if Reporter is None:
            raise ImportError(
                "TTSInterface requires chainer for its Reporter; "
                "please install chainer."
            )
        self.reporter = Reporter()

    def forward(self, *args, **kwargs):
        """Calculate TTS forward propagation.

        Returns:
            Tensor: Loss value.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("forward method is not implemented")

    def inference(self, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Returns:
            Tensor: The sequence of generated features (L, odim).
            Tensor: The sequence of stop probabilities (L,).
            Tensor: The sequence of attention weights (L, T).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("inference method is not implemented")

    def calculate_all_attentions(self, *args, **kwargs):
        """Calculate TTS attention weights.

        Returns:
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    def load_pretrained_model(self, model_path):
        """Load pretrained model parameters into this model.

        Args:
            model_path (str): Path to the saved model, passed to ``torch_load``.
        """
        torch_load(model_path, self)

    @property
    def attention_plot_class(self):
        """Return the class used to plot attention weights.

        The import is done lazily, when the property is accessed.
        """
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        The keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss` and
        `validation/main/loss` values. Also, `loss.png` will be created as a
        figure visualizing `main/loss` and `validation/main/loss` values.

        Returns:
            list[str]: Base keys to plot during training.
        """
        return ["loss"]
|