File size: 2,017 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from coqpit import Coqpit

from TTS.model import BaseTrainerModel

# pylint: skip-file


class BaseVocoder(BaseTrainerModel):
    """Base `vocoder` class. Every new `vocoder` model must inherit this.

    It defines `vocoder` specific functions on top of `Model`.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    MODEL_TYPE = "vocoder"

    def __init__(self, config):
        """Initialize the parent model and store the config / model args.

        Args:
            config: A `*Config` or `*Args` object; the kind is decided from its class name.

        Raises:
            ValueError: If the class name of `config` contains neither "Config" nor "Args".
        """
        super().__init__()
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Setup model args based on the config type.

        If the config is for training with a name like "*Config", then the model args are embedded in
        `config.model_args` (or `config.model_params` for older configs). `self.config` is set, and
        `self.args` is set when one of those fields is found.

        If the config is for the model with a name like "*Args", then we assign it directly to
        `self.args`; `self.config` is not touched in that case.

        Args:
            config (Coqpit): A `*Config` or `*Args` object.

        Raises:
            ValueError: If the class name of `config` contains neither "Config" nor "Args".
        """
        # Match on the class name instead of `isinstance` to avoid recursive imports.
        config_type = config.__class__.__name__
        if "Config" in config_type:
            if "characters" in config:
                # `get_characters` is not defined in this class; presumably it is provided by the
                # parent class or a subclass -- TODO confirm. Its 2nd and 3rd return values are
                # used as the config to keep and the number of characters.
                _, self.config, num_chars = self.get_characters(config)
                self.config.num_chars = num_chars
                if hasattr(self.config, "model_args"):
                    # NOTE(review): this writes through the passed-in `config` while `self.args`
                    # is read from `self.config`; presumably both hold the same `model_args`
                    # object -- confirm against `get_characters`.
                    config.model_args.num_chars = num_chars
                    if "model_args" in config:
                        self.args = self.config.model_args
                    # This is for backward compatibility
                    if "model_params" in config:
                        self.args = self.config.model_params
            else:
                self.config = config
                if "model_args" in config:
                    self.args = self.config.model_args
                # This is for backward compatibility
                if "model_params" in config:
                    self.args = self.config.model_params
        elif "Args" in config_type:
            # Bare model args: nothing to unpack, keep them as they are.
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")