File size: 4,153 Bytes
d5ee97c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
# Copyright 2020 The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Auto Model modules."""

import logging
import warnings
import os
import copy

from collections import OrderedDict

from tensorflow_tts.configs import (
    FastSpeechConfig,
    FastSpeech2Config,
    MelGANGeneratorConfig,
    MultiBandMelGANGeneratorConfig,
    HifiGANGeneratorConfig,
    Tacotron2Config,
    ParallelWaveGANGeneratorConfig,
)

from tensorflow_tts.models import (
    TFMelGANGenerator,
    TFMBMelGANGenerator,
    TFHifiGANGenerator,
    TFParallelWaveGANGenerator,
)

from tensorflow_tts.inference.savable_models import (
    SavableTFFastSpeech,
    SavableTFFastSpeech2,
    SavableTFTacotron2
)
from tensorflow_tts.utils import CACHE_DIRECTORY, MODEL_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download


# Ordered lookup table from config class to the model class built for it.
# Consumers iterate it in insertion order and stop at the first match, so the
# order below is significant and must be preserved.
# NOTE(review): FastSpeech2Config / MultiBandMelGANGeneratorConfig are listed
# ahead of FastSpeechConfig / MelGANGeneratorConfig, presumably because they
# subclass them and must be matched first -- confirm against the config classes.
TF_MODEL_MAPPING = OrderedDict()
TF_MODEL_MAPPING[FastSpeech2Config] = SavableTFFastSpeech2
TF_MODEL_MAPPING[FastSpeechConfig] = SavableTFFastSpeech
TF_MODEL_MAPPING[MultiBandMelGANGeneratorConfig] = TFMBMelGANGenerator
TF_MODEL_MAPPING[MelGANGeneratorConfig] = TFMelGANGenerator
TF_MODEL_MAPPING[Tacotron2Config] = SavableTFTacotron2
TF_MODEL_MAPPING[HifiGANGeneratorConfig] = TFHifiGANGenerator
TF_MODEL_MAPPING[ParallelWaveGANGeneratorConfig] = TFParallelWaveGANGenerator


class TFAutoModel(object):
    """General model class for inferencing.

    This class is a factory only: use :meth:`from_pretrained` to obtain a
    model instance. Calling the constructor directly always fails.
    """

    def __init__(self):
        raise EnvironmentError("Cannot be instantiated using `__init__()`")

    @classmethod
    def from_pretrained(cls, pretrained_path=None, config=None, **kwargs):
        """Build the model matching ``config`` and optionally load its weights.

        Args:
            pretrained_path: Either a path to an existing local file or,
                when no such file exists, a value passed as ``repo_id`` to
                ``hf_hub_url`` so the weights file (``MODEL_FILE_NAME``) is
                downloaded into ``CACHE_DIRECTORY``. ``None`` builds the
                model without loading any weights.
            config: Configuration object selecting the model class through
                ``TF_MODEL_MAPPING``. May be omitted only when
                ``pretrained_path`` is not a local file, in which case it is
                obtained from ``AutoConfig.from_pretrained(pretrained_path)``.
            **kwargs: Forwarded unchanged to the model class constructor.

        Returns:
            The constructed model, after ``set_config`` and ``_build`` have
            been called on it and (for ``.h5`` paths) weights were loaded.

        Raises:
            AssertionError: If no config is available once the path has been
                resolved.
            ValueError: If ``config`` does not match any entry of
                ``TF_MODEL_MAPPING``.
        """
        # Anything that is not an existing local file is resolved on the hub.
        if pretrained_path is not None and not os.path.isfile(pretrained_path):
            download_url = hf_hub_url(
                repo_id=pretrained_path, filename=MODEL_FILE_NAME
            )
            downloaded_file = str(
                cached_download(
                    url=download_url,
                    library_name=LIBRARY_NAME,
                    library_version=VERSION,
                    cache_dir=CACHE_DIRECTORY,
                )
            )

            # Load the config from the same repo when none was supplied.
            if config is None:
                # Imported here rather than at module level, as in the
                # original code (presumably to avoid a circular import).
                from tensorflow_tts.inference import AutoConfig

                config = AutoConfig.from_pretrained(pretrained_path)

            pretrained_path = downloaded_file

        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type is kept for callers.
        if config is None:
            raise AssertionError(
                "Please make sure to pass a config along to load a model from a local file"
            )

        for config_class, model_class in TF_MODEL_MAPPING.items():
            matches = isinstance(config, config_class) and str(
                config_class.__name__
            ) in str(config)
            if not matches:
                continue

            model = model_class(config=config, **kwargs)
            model.set_config(config)
            model._build()
            if pretrained_path is not None and ".h5" in pretrained_path:
                try:
                    model.load_weights(pretrained_path)
                except Exception:
                    # Only ordinary errors trigger the lenient retry;
                    # KeyboardInterrupt / SystemExit now propagate instead of
                    # being swallowed by a bare `except:`.
                    logging.getLogger(__name__).warning(
                        "Loading weights from %s failed; retrying with "
                        "by_name=True, skip_mismatch=True.",
                        pretrained_path,
                        exc_info=True,
                    )
                    model.load_weights(
                        pretrained_path, by_name=True, skip_mismatch=True
                    )
            return model

        raise ValueError(
            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys()),
            )
        )