File size: 5,945 Bytes
c02bdcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import random
from typing import Optional
from time import sleep

import gradio as gr

from tools.audio import float_to_int16, has_ffmpeg_installed, load_audio
from tools.logger import get_logger

logger = get_logger(" WebUI ")

from tools.seeder import TorchSeedContext
from tools.normalizer import normalizer_en_nemo_text, normalizer_zh_tn

import ChatTTS

# Shared ChatTTS model instance used by every callback in this module.
chat = ChatTTS.Chat(get_logger("ChatTTS"))

# Local model directory; assigned by load_chat() when a custom path is loaded
# and passed back to load_chat() by reload_chat().
custom_path: Optional[str] = None

# Set by interrupt_generate() and cleared by set_buttons_before_generate();
# generate_audio() yields nothing while it is set.
has_interrupted: bool = False
# True between set_buttons_before_generate() and set_buttons_after_generate();
# reload_chat() refuses to reload while it is set.
is_in_generate: bool = False

# Inclusive bounds for seeds drawn by generate_seed() (seed_max == 2**32 - 1).
seed_min: int = 1
seed_max: int = 4294967295

# Whether ffmpeg is available; the warning below notes the WAV fallback.
# Presumably read elsewhere to choose the output file format — confirm.
use_mp3 = has_ffmpeg_installed()
if not use_mp3:
    logger.warning("no ffmpeg installed, use wav file output")

# Timbre (voice) presets: each name maps to a fixed seed that selects a
# suitable voice; looked up by on_voice_change().
voices = {
    "Default": {"seed": 2},
    "Timbre1": {"seed": 1111},
    "Timbre2": {"seed": 2222},
    "Timbre3": {"seed": 3333},
    "Timbre4": {"seed": 4444},
    "Timbre5": {"seed": 5555},
    "Timbre6": {"seed": 6666},
    "Timbre7": {"seed": 7777},
    "Timbre8": {"seed": 8888},
    "Timbre9": {"seed": 9999},
}


def generate_seed():
    """Return a Gradio update carrying a fresh random seed.

    The seed is drawn uniformly from the inclusive range
    ``[seed_min, seed_max]``.
    """
    fresh_seed = random.randint(seed_min, seed_max)
    return gr.update(value=fresh_seed)


# 返回选择音色对应的seed
def on_voice_change(vocie_selection):
    """Return the preset seed for the timbre named *vocie_selection*.

    Unknown selections (including ``None``) fall back to the ``"Default"``
    timbre's seed and log a warning. Previously ``voices.get(...)`` returned
    ``None`` and indexing it raised an opaque
    ``TypeError: 'NoneType' object is not subscriptable``.

    The parameter keeps its original (misspelled) name so keyword callers
    keep working.
    """
    voice = voices.get(vocie_selection)
    if voice is None:
        logger.warning(
            "unknown voice selection %r, falling back to Default", vocie_selection
        )
        voice = voices["Default"]
    return voice["seed"]


def on_audio_seed_change(audio_seed_input):
    """Sample a random speaker from ``chat`` under a fixed seed.

    The sampling runs inside ``TorchSeedContext(audio_seed_input)``, so the
    same seed should reproduce the same speaker.
    """
    seeded = TorchSeedContext(audio_seed_input)
    with seeded:
        speaker = chat.sample_random_speaker()
    return speaker


def _register_normalizer(lang: str, factory, package: str) -> None:
    """Best-effort: build a normalizer with *factory* and register it for *lang*.

    A ``ValueError`` is logged as an error. Any other exception is reported as
    the optional *package* being unavailable, with an install hint. Nothing
    propagates, so a missing normalizer never blocks model loading.
    """
    try:
        chat.normalizer.register(lang, factory())
    except ValueError as e:
        logger.error(e)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # are no longer swallowed.
        logger.warning("Package %s not found!", package)
        logger.warning(
            "Run: conda install -c conda-forge pynini=2.1.5 && pip install %s",
            package,
        )


def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
    """Load the ChatTTS model and register the optional text normalizers.

    Args:
        cust_path: Local model directory, or ``None`` to call ``chat.load``
            without a custom source.
        coef: DVAE coefficient forwarded to ``chat.load``.

    Returns:
        The value returned by ``chat.load``. The English (nemo) and Chinese
        (WeTextProcessing) normalizers are registered only when it is truthy.

    Side effects:
        When *cust_path* is given, it is stored in the module-level
        ``custom_path`` after ``chat.load`` returns, so ``reload_chat`` can
        reuse it.
    """
    global custom_path

    if cust_path is None:
        ret = chat.load(coef=coef)
    else:
        logger.info("local model path: %s", cust_path)
        ret = chat.load("custom", custom_path=cust_path, coef=coef)
        custom_path = cust_path
    if ret:
        _register_normalizer("en", normalizer_en_nemo_text, "nemo_text_processing")
        _register_normalizer("zh", normalizer_zh_tn, "WeTextProcessing")
    return ret


def reload_chat(coef: Optional[str]) -> str:
    """Unload the model and load it again with a new DVAE coefficient.

    Args:
        coef: DVAE coefficient string from the UI. A string whose length is
            not 230 is ignored with a warning and replaced by ``None``.
            ``None`` itself is forwarded to ``load_chat`` unchanged. The old
            code called ``len(None)`` and crashed with ``TypeError``.

    Returns:
        ``chat.coef`` after a successful reload. If a generation is in
        progress, returns *coef* unchanged and does not touch the model.

    Raises:
        gr.Error: if ``load_chat`` raises (the original exception is chained)
            or reports failure.
    """
    if is_in_generate:
        gr.Warning("Cannot reload when generating!")
        return coef

    chat.unload()
    gr.Info("Model unloaded.")
    # 230 is the only coefficient length this UI accepts.
    if coef is not None and len(coef) != 230:
        gr.Warning("Ignore invalid DVAE coefficient.")
        coef = None
    try:
        ret = load_chat(custom_path, coef)
    except Exception as e:
        raise gr.Error(str(e)) from e
    if not ret:
        raise gr.Error("Unable to load model.")
    gr.Info("Reload success.")
    return chat.coef


def on_upload_sample_audio(sample_audio_input: Optional[str]) -> str:
    """Encode an uploaded reference clip into a speaker sample.

    Returns an empty string when nothing was uploaded. Otherwise the file is
    loaded at 24000 Hz and passed to ``chat.sample_audio_speaker``.
    """
    if sample_audio_input is not None:
        waveform = load_audio(sample_audio_input, 24000)
        speaker_sample = chat.sample_audio_speaker(waveform)
        # Release the decoded waveform before returning.
        del waveform
        return speaker_sample
    return ""


def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):
    """Return Gradio updates for the (generate, interrupt) button pair.

    With *is_reset* true, the generate button is shown and enabled while the
    interrupt button is hidden and disabled. Otherwise the roles swap.
    """
    show_interrupt = not is_reset
    generate_update = gr.update(
        value=generate_button, visible=is_reset, interactive=is_reset
    )
    interrupt_update = gr.update(
        value=interrupt_button, visible=show_interrupt, interactive=show_interrupt
    )
    return generate_update, interrupt_update


def refine_text(
    text,
    text_seed_input,
    refine_text_flag,
    temperature,
    top_P,
    top_K,
):
    """Optionally run ChatTTS text refinement on *text*.

    When *refine_text_flag* is set, ``chat.infer`` runs in refine-only mode
    with the given sampling parameters and seed. The first element is
    returned if the result is a list. When the flag is not set, *text* is
    returned unchanged after a one-second pause.
    """
    if refine_text_flag:
        refine_params = ChatTTS.Chat.RefineTextParams(
            temperature=temperature,
            top_P=top_P,
            top_K=top_K,
            manual_seed=text_seed_input,
        )
        refined = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=refine_params,
        )
        if isinstance(refined, list):
            return refined[0]
        return refined

    # Deliberate delay; presumably so the UI's loading indicator does not
    # just flash when no refinement work is done.
    sleep(1)
    return text


def generate_audio(
    text,
    temperature,
    top_P,
    top_K,
    spk_emb_text: str,
    stream,
    audio_seed_input,
    sample_text_input,
    sample_audio_code_input,
):
    """Synthesize *text*, yielding ``(24000, int16_audio)`` tuples.

    Yields nothing in three cases: *text* is empty, the run has been
    interrupted, or *spk_emb_text* does not start with the expected
    speaker-embedding prefix.

    When both sample inputs are provided, they are set as ``txt_smp`` /
    ``spk_smp`` and the speaker embedding is cleared. In streaming mode one
    tuple is yielded per non-empty chunk; otherwise a single tuple is
    yielded for the whole result.
    """
    # Guard clauses, evaluated in the same order as a short-circuit ``or``.
    if not text:
        return None
    if has_interrupted:
        return None
    if not spk_emb_text.startswith("蘁淰"):
        return None

    code_params = ChatTTS.Chat.InferCodeParams(
        spk_emb=spk_emb_text,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        manual_seed=audio_seed_input,
    )

    if sample_text_input and sample_audio_code_input:
        code_params.txt_smp = sample_text_input
        code_params.spk_smp = sample_audio_code_input
        code_params.spk_emb = None

    result = chat.infer(
        text,
        skip_refine_text=True,
        params_infer_code=code_params,
        stream=stream,
    )

    if not stream:
        yield 24000, float_to_int16(result[0]).T
        return

    for chunk in result:
        piece = chunk[0]
        if piece is not None and len(piece) > 0:
            yield 24000, float_to_int16(piece).T
        # Drop the reference before pulling the next chunk.
        del piece


def interrupt_generate():
    """Flag the current generation as interrupted and ask ``chat`` to stop.

    The module-level flag is set before ``chat.interrupt()`` is called.
    """
    global has_interrupted

    has_interrupted = True
    chat.interrupt()


def set_buttons_before_generate(generate_button, interrupt_button):
    """Mark a generation as started and swap in the interrupt button.

    Clears any previous interrupt flag and sets ``is_in_generate``, which
    makes ``reload_chat`` refuse to reload until generation finishes.
    """
    global has_interrupted, is_in_generate

    is_in_generate = True
    has_interrupted = False
    return _set_generate_buttons(generate_button, interrupt_button, is_reset=False)


def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
    """Mark generation as finished and choose which button to show.

    The generate button comes back when audio was produced or the run was
    interrupted. Otherwise the interrupt button stays visible.
    """
    global is_in_generate

    is_in_generate = False
    restore_generate = audio_output is not None or has_interrupted
    return _set_generate_buttons(
        generate_button, interrupt_button, restore_generate
    )