Jofthomas HF staff commited on
Commit
bab6276
1 Parent(s): bbcd7e3

Create coqui.py

Browse files
Files changed (1) hide show
  1. coqui.py +389 -0
coqui.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import io, os, stat
3
+ import subprocess
4
+ import random
5
+ from zipfile import ZipFile
6
+ import uuid
7
+ import time
8
+ import torch
9
+ import torchaudio
10
+ import numpy as np
11
+
12
+
13
+ #update gradio to faster streaming
14
+ #download for mecab
15
+ os.system('python -m unidic download')
16
+
17
+ # By using XTTS you agree to CPML license https://coqui.ai/cpml
18
+ os.environ["COQUI_TOS_AGREED"] = "1"
19
+
20
+ # langid is used to detect language for longer text
21
+ # Most users expect text to be their own language, there is checkbox to disable it
22
+ import langid
23
+ import base64
24
+ import csv
25
+ from io import StringIO
26
+ import datetime
27
+ import re
28
+
29
+ from scipy.io.wavfile import write
30
+ from pydub import AudioSegment
31
+
32
+ from TTS.api import TTS
33
+ from TTS.tts.configs.xtts_config import XttsConfig
34
+ from TTS.tts.models.xtts import Xtts
35
+ from TTS.utils.generic_utils import get_user_data_dir
36
+
37
+ HF_TOKEN = os.environ.get("HF_TOKEN")
38
+
39
+ from huggingface_hub import HfApi
40
+ # will use api to restart space on a unrecoverable error
41
+ api = HfApi(token=HF_TOKEN)
42
+ repo_id = "coqui/xtts"
43
+
44
+ # This will trigger downloading model
45
+ print("Downloading if not downloaded Coqui XTTS V2")
46
+ from TTS.utils.manage import ModelManager
47
+
48
+ model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
49
+ ModelManager().download_model(model_name)
50
+ model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
51
+ print("XTTS downloaded")
52
+
53
+ config = XttsConfig()
54
+ config.load_json(os.path.join(model_path, "config.json"))
55
+
56
+ model = Xtts.init_from_config(config)
57
+ model.load_checkpoint(
58
+ config,
59
+ checkpoint_path=os.path.join(model_path, "model.pth"),
60
+ vocab_path=os.path.join(model_path, "vocab.json"),
61
+ eval=True,
62
+ use_deepspeed=True,
63
+ )
64
+ model.cuda()
65
+
66
+ # This is for debugging purposes only
67
+ DEVICE_ASSERT_DETECTED = 0
68
+ DEVICE_ASSERT_PROMPT = None
69
+ DEVICE_ASSERT_LANG = None
70
+
71
+ supported_languages = config.languages
72
+ def numpy_to_mp3(audio_array, sampling_rate):
73
+ # Normalize audio_array if it's floating-point
74
+ if np.issubdtype(audio_array.dtype, np.floating):
75
+ max_val = np.max(np.abs(audio_array))
76
+ audio_array = (audio_array / max_val) * 32767 # Normalize to 16-bit range
77
+ audio_array = audio_array.astype(np.int16)
78
+
79
+ # Create an audio segment from the numpy array
80
+ audio_segment = AudioSegment(
81
+ audio_array.tobytes(),
82
+ frame_rate=sampling_rate,
83
+ sample_width=audio_array.dtype.itemsize,
84
+ channels=1
85
+ )
86
+
87
+ # Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
88
+ mp3_io = io.BytesIO()
89
+ audio_segment.export(mp3_io, format="mp3", bitrate="320k")
90
+
91
+ # Get the MP3 bytes
92
+ mp3_bytes = mp3_io.getvalue()
93
+ mp3_io.close()
94
+
95
+ return mp3_bytes
96
+
97
+ def predict(
98
+ prompt,
99
+ language,
100
+ audio_file_pth,
101
+ mic_file_path,
102
+ use_mic,
103
+ voice_cleanup,
104
+ no_lang_auto_detect,
105
+ agree,
106
+ ):
107
+ if agree == True:
108
+ if language not in supported_languages:
109
+ gr.Warning(
110
+ f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
111
+ )
112
+
113
+ return (
114
+ None,
115
+ )
116
+
117
+ language_predicted = langid.classify(prompt)[
118
+ 0
119
+ ].strip() # strip need as there is space at end!
120
+
121
+ # tts expects chinese as zh-cn
122
+ if language_predicted == "zh":
123
+ # we use zh-cn
124
+ language_predicted = "zh-cn"
125
+
126
+ print(f"Detected language:{language_predicted}, Chosen language:{language}")
127
+
128
+ # After text character length 15 trigger language detection
129
+ if len(prompt) > 15:
130
+ # allow any language for short text as some may be common
131
+ # If user unchecks language autodetection it will not trigger
132
+ # You may remove this completely for own use
133
+ if language_predicted != language and not no_lang_auto_detect:
134
+ # Please duplicate and remove this check if you really want this
135
+ # Or auto-detector fails to identify language (which it can on pretty short text or mixed text)
136
+ gr.Warning(
137
+ f"It looks like your text isn’t the language you chose , if you’re sure the text is the same language you chose, please check disable language auto-detection checkbox"
138
+ )
139
+
140
+ return (
141
+ None,
142
+ )
143
+
144
+ if use_mic == True:
145
+ if mic_file_path is not None:
146
+ speaker_wav = mic_file_path
147
+ else:
148
+ gr.Warning(
149
+ "Please record your voice with Microphone, or uncheck Use Microphone to use reference audios"
150
+ )
151
+ return (
152
+ None,
153
+ )
154
+
155
+ else:
156
+ speaker_wav = audio_file_pth
157
+
158
+ # Filtering for microphone input, as it has BG noise, maybe silence in beginning and end
159
+ # This is fast filtering not perfect
160
+
161
+ # Apply all on demand
162
+ lowpassfilter = denoise = trim = loudness = True
163
+
164
+ if lowpassfilter:
165
+ lowpass_highpass = "lowpass=8000,highpass=75,"
166
+ else:
167
+ lowpass_highpass = ""
168
+
169
+ if trim:
170
+ # better to remove silence in beginning and end for microphone
171
+ trim_silence = "areverse,silenceremove=start_periods=1:start_silence=0:start_threshold=0.02,areverse,silenceremove=start_periods=1:start_silence=0:start_threshold=0.02,"
172
+ else:
173
+ trim_silence = ""
174
+
175
+ if voice_cleanup:
176
+ try:
177
+ out_filename = (
178
+ speaker_wav + str(uuid.uuid4()) + ".wav"
179
+ ) # ffmpeg to know output format
180
+
181
+ # we will use newer ffmpeg as that has afftn denoise filter
182
+ shell_command = f"./ffmpeg -y -i {speaker_wav} -af {lowpass_highpass}{trim_silence} {out_filename}".split(
183
+ " "
184
+ )
185
+
186
+ command_result = subprocess.run(
187
+ [item for item in shell_command],
188
+ capture_output=False,
189
+ text=True,
190
+ check=True,
191
+ )
192
+ speaker_wav = out_filename
193
+ print("Filtered microphone input")
194
+ except subprocess.CalledProcessError:
195
+ # There was an error - command exited with non-zero code
196
+ print("Error: failed filtering, use original microphone input")
197
+ else:
198
+ speaker_wav = speaker_wav
199
+
200
+ if len(prompt) < 2:
201
+ gr.Warning("Please give a longer prompt text")
202
+ return (
203
+ None,
204
+ )
205
+ if len(prompt) > 1000:
206
+ gr.Warning(
207
+ "Text length limited to 200 characters for this demo, please try shorter text. You can clone this space and edit code for your own usage"
208
+ )
209
+ return (
210
+ None,
211
+ )
212
+ global DEVICE_ASSERT_DETECTED
213
+ if DEVICE_ASSERT_DETECTED:
214
+ global DEVICE_ASSERT_PROMPT
215
+ global DEVICE_ASSERT_LANG
216
+ # It will likely never come here as we restart space on first unrecoverable error now
217
+ print(
218
+ f"Unrecoverable exception caused by language:{DEVICE_ASSERT_LANG} prompt:{DEVICE_ASSERT_PROMPT}"
219
+ )
220
+
221
+ # HF Space specific.. This error is unrecoverable need to restart space
222
+ space = api.get_space_runtime(repo_id=repo_id)
223
+ if space.stage != "BUILDING":
224
+ api.restart_space(repo_id=repo_id)
225
+ else:
226
+ print("TRIED TO RESTART but space is building")
227
+
228
+ try:
229
+ metrics_text = ""
230
+ t_latent = time.time()
231
+
232
+ # note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference
233
+ try:
234
+ (
235
+ gpt_cond_latent,
236
+ speaker_embedding,
237
+ ) = model.get_conditioning_latents(audio_path=speaker_wav, gpt_cond_len=30, gpt_cond_chunk_len=4, max_ref_length=60)
238
+ except Exception as e:
239
+ print("Speaker encoding error", str(e))
240
+ gr.Warning(
241
+ "It appears something wrong with reference, did you unmute your microphone?"
242
+ )
243
+ return (
244
+ None,
245
+ )
246
+
247
+ latent_calculation_time = time.time() - t_latent
248
+ # metrics_text=f"Embedding calculation time: {latent_calculation_time:.2f} seconds\n"
249
+
250
+ # temporary comma fix
251
+ prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
252
+
253
+ wav_chunks = []
254
+ ## Direct mode
255
+ """
256
+ print("I: Generating new audio...")
257
+ t0 = time.time()
258
+ out = model.inference(
259
+ prompt,
260
+ language,
261
+ gpt_cond_latent,
262
+ speaker_embedding,
263
+ repetition_penalty=5.0,
264
+ temperature=0.75,
265
+ )
266
+ inference_time = time.time() - t0
267
+ print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
268
+ metrics_text+=f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
269
+ real_time_factor= (time.time() - t0) / out['wav'].shape[-1] * 24000
270
+ print(f"Real-time factor (RTF): {real_time_factor}")
271
+ metrics_text+=f"Real-time factor (RTF): {real_time_factor:.2f}\n"
272
+ torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
273
+ """
274
+ print("I: Generating new audio in streaming mode...")
275
+ t0 = time.time()
276
+ chunks = model.inference_stream(
277
+ prompt,
278
+ language,
279
+ gpt_cond_latent,
280
+ speaker_embedding,
281
+ repetition_penalty=7.0,
282
+ temperature=0.85,
283
+ )
284
+
285
+ first_chunk = True
286
+ for i, chunk in enumerate(chunks):
287
+ if first_chunk:
288
+ first_chunk_time = time.time() - t0
289
+ metrics_text += f"Latency to first audio chunk: {round(first_chunk_time*1000)} milliseconds\n"
290
+ first_chunk = False
291
+
292
+ # Convert chunk to numpy array and return it
293
+ chunk_np = chunk.cpu().numpy()
294
+ print('chunk',i)
295
+ yield (24000, chunk_np)
296
+ wav_chunks.append(chunk)
297
+
298
+ print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
299
+ inference_time = time.time() - t0
300
+ print(
301
+ f"I: Time to generate audio: {round(inference_time*1000)} milliseconds"
302
+ )
303
+ # metrics_text += (
304
+ # f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
305
+ #)
306
+
307
+ except RuntimeError as e:
308
+ if "device-side assert" in str(e):
309
+ # cannot do anything on cuda device side error, need tor estart
310
+ print(
311
+ f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
312
+ flush=True,
313
+ )
314
+ gr.Warning("Unhandled Exception encounter, please retry in a minute")
315
+ print("Cuda device-assert Runtime encountered need restart")
316
+ if not DEVICE_ASSERT_DETECTED:
317
+ DEVICE_ASSERT_DETECTED = 1
318
+ DEVICE_ASSERT_PROMPT = prompt
319
+ DEVICE_ASSERT_LANG = language
320
+
321
+ # just before restarting save what caused the issue so we can handle it in future
322
+ # Uploading Error data only happens for unrecovarable error
323
+ error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
324
+ error_data = [
325
+ error_time,
326
+ prompt,
327
+ language,
328
+ audio_file_pth,
329
+ mic_file_path,
330
+ use_mic,
331
+ voice_cleanup,
332
+ no_lang_auto_detect,
333
+ agree,
334
+ ]
335
+ error_data = [str(e) if type(e) != str else e for e in error_data]
336
+ print(error_data)
337
+ print(speaker_wav)
338
+ write_io = StringIO()
339
+ csv.writer(write_io).writerows([error_data])
340
+ csv_upload = write_io.getvalue().encode()
341
+
342
+ filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
343
+ print("Writing error csv")
344
+ error_api = HfApi()
345
+ error_api.upload_file(
346
+ path_or_fileobj=csv_upload,
347
+ path_in_repo=filename,
348
+ repo_id="coqui/xtts-flagged-dataset",
349
+ repo_type="dataset",
350
+ )
351
+
352
+ # speaker_wav
353
+ print("Writing error reference audio")
354
+ speaker_filename = (
355
+ error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
356
+ )
357
+ error_api = HfApi()
358
+ error_api.upload_file(
359
+ path_or_fileobj=speaker_wav,
360
+ path_in_repo=speaker_filename,
361
+ repo_id="coqui/xtts-flagged-dataset",
362
+ repo_type="dataset",
363
+ )
364
+
365
+ # HF Space specific.. This error is unrecoverable need to restart space
366
+ space = api.get_space_runtime(repo_id=repo_id)
367
+ if space.stage != "BUILDING":
368
+ api.restart_space(repo_id=repo_id)
369
+ else:
370
+ print("TRIED TO RESTART but space is building")
371
+
372
+ else:
373
+ if "Failed to decode" in str(e):
374
+ print("Speaker encoding error", str(e))
375
+ gr.Warning(
376
+ "It appears something wrong with reference, did you unmute your microphone?"
377
+ )
378
+ else:
379
+ print("RuntimeError: non device-side assert error:", str(e))
380
+ gr.Warning("Something unexpected happened please retry again.")
381
+ return (
382
+ None,
383
+ )
384
+
385
+ else:
386
+ gr.Warning("Please accept the Terms & Condition!")
387
+ return (
388
+ None,
389
+ )