Upload 86 files
(This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.)
- .gitattributes +23 -0
- VietTTS/__init__.py +0 -0
- VietTTS/cli.py +114 -0
- VietTTS/flow/decoder.py +649 -0
- VietTTS/flow/flow.py +158 -0
- VietTTS/flow/flow_matching.py +268 -0
- VietTTS/flow/length_regulator.py +56 -0
- VietTTS/frontend.py +151 -0
- VietTTS/hifigan/f0_predictor.py +42 -0
- VietTTS/hifigan/generator.py +384 -0
- VietTTS/llm/llm.py +199 -0
- VietTTS/model.py +260 -0
- VietTTS/models/.cache/huggingface/.gitignore +1 -0
- VietTTS/models/.cache/huggingface/download/.gitattributes.lock +0 -0
- VietTTS/models/.cache/huggingface/download/.gitattributes.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/README.md.lock +0 -0
- VietTTS/models/.cache/huggingface/download/README.md.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/README_VN.md.lock +0 -0
- VietTTS/models/.cache/huggingface/download/README_VN.md.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/config.yaml.lock +0 -0
- VietTTS/models/.cache/huggingface/download/config.yaml.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/flow.pt.lock +0 -0
- VietTTS/models/.cache/huggingface/download/flow.pt.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/hift.pt.lock +0 -0
- VietTTS/models/.cache/huggingface/download/hift.pt.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/llm.pt.lock +0 -0
- VietTTS/models/.cache/huggingface/download/llm.pt.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.lock +0 -0
- VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.metadata +3 -0
- VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.lock +0 -0
- VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.metadata +3 -0
- VietTTS/models/.gitattributes +35 -0
- VietTTS/models/README.md +213 -0
- VietTTS/models/README_VN.md +203 -0
- VietTTS/models/config.yaml +129 -0
- VietTTS/models/flow.pt +3 -0
- VietTTS/models/hift.pt +3 -0
- VietTTS/models/llm.pt +3 -0
- VietTTS/models/speech_embedding.onnx +3 -0
- VietTTS/models/speech_tokenizer.onnx +3 -0
- VietTTS/samples/cdteam.wav +3 -0
- VietTTS/samples/cross_lingual_prompt.wav +3 -0
- VietTTS/samples/diep-chi.wav +3 -0
- VietTTS/samples/doremon.mp3 +3 -0
- VietTTS/samples/jack-sparrow.mp3 +3 -0
- VietTTS/samples/nguyen-ngoc-ngan.wav +3 -0
- VietTTS/samples/nsnd-le-chuc.mp3 +3 -0
- VietTTS/samples/nu-nhe-nhang.wav +3 -0
- VietTTS/samples/quynh.wav +3 -0
- VietTTS/samples/son-tung-mtp.wav +3 -0
.gitattributes
CHANGED
@@ -60,3 +60,26 @@ Vinorm/vinorm/lib/libicuuc.so filter=lfs diff=lfs merge=lfs -text
 Vinorm/vinorm/lib/libicuuc.so.64 filter=lfs diff=lfs merge=lfs -text
 Vinorm/vinorm/lib/libicuuc.so.64.2 filter=lfs diff=lfs merge=lfs -text
 Vinorm/vinorm/main filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/cdteam.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/cross_lingual_prompt.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/diep-chi.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/doremon.mp3 filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/jack-sparrow.mp3 filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/nguyen-ngoc-ngan.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/nsnd-le-chuc.mp3 filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/nu-nhe-nhang.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/quynh.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/son-tung-mtp.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_1.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_10.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_11.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_12.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_2.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_3.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_4.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_5.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_6.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_7.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_8.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/speechify_9.wav filter=lfs diff=lfs merge=lfs -text
+VietTTS/samples/zero_shot_prompt.wav filter=lfs diff=lfs merge=lfs -text
VietTTS/__init__.py
ADDED
File without changes
VietTTS/cli.py
ADDED
@@ -0,0 +1,114 @@
import os
import sys
import time
import click
import subprocess
from loguru import logger
from rich.table import Table
from rich.console import Console
from VietTTS.tts import TTS
from VietTTS.utils.file_utils import load_prompt_speech_from_file, load_voices


AUDIO_DIR = 'samples'
MODEL_DIR = 'pretrained-models'

@click.command('server')
@click.option('-h', '--host', type=str, default='0.0.0.0', help="The host address to bind the server to. Default is '0.0.0.0'.")
@click.option('-p', '--port', type=int, default=8298, help="The port number to bind the server to. Default is 8298.")
@click.option('-w', '--workers', type=int, default=1, help="The number of worker processes to handle requests. Default is 1.")
def start_server(host: str, port: int, workers: int):
    """Start the API server (OpenAI TTS API compatible).

    Usage: viettts server --host 0.0.0.0 --port 8298 -w 4
    """
    logger.info("Starting server")
    cmd = f'gunicorn viettts.server:app \
        -k uvicorn.workers.UvicornWorker \
        --bind {host}:{port} \
        --workers {workers} \
        --max-requests 1000 \
        --max-requests-jitter 50 \
        --timeout 300 \
        --keep-alive 75 \
        --graceful-timeout 60'

    subprocess.call(cmd, shell=True, stdout=sys.stdout)


@click.command('synthesis')
@click.option('-t', "--text", type=str, required=True, help="The input text to synthesize into speech.")
@click.option('-v', "--voice", type=str, default='1', help="The voice ID or file path to clone the voice from. Default is '1'.")
@click.option('-s', "--speed", type=float, default=1, help="The speed multiplier for the speech. Default is 1 (normal speed).")
@click.option('-o', "--output", type=str, default='output.wav', help="The file path to save the synthesized audio. Default is 'output.wav'.")
def synthesis(text: str, voice: str, speed: float, output: str):
    """Synthesize audio from text and save it to a file.

    Usage: viettts synthesis --text 'Xin chào VietTTS' --voice nu-nhe-nhang --voice 8 --speed 1.2 --output test_nu-nhe-nhang.wav
    """
    logger.info("Starting synthesis")
    st = time.perf_counter()
    if not text:
        logger.error('text must not be empty')
        return

    if speed > 2 or speed < 0.5:
        logger.error('speed must be in range 0.5-2.0')
        return

    if not os.path.exists(voice):
        voice_map = load_voices(AUDIO_DIR)
        if voice.isdigit():
            voice = list(voice_map.values())[int(voice)]
        else:
            voice = voice_map.get(voice)

        if not os.path.exists(voice):
            logger.error('voice is not available. Use --voice <voice-name/voice-id/local-file> or run `viettts show-voices` to get available voices.')
            return

    logger.info('Loading model')
    tts = TTS(model_dir=MODEL_DIR)

    logger.info('Loading voice')
    voice = load_prompt_speech_from_file(voice)

    logger.info('Processing')
    tts.tts_to_file(text, voice, speed, output)

    et = time.perf_counter()
    logger.success(f"Saved to: {output} [time cost={et-st:.2f}s]")


@click.command('show-voices')
def show_voice():
    """Print all available voices.

    Usage: viettts show-voices
    """
    voice_map = load_voices(AUDIO_DIR)
    console = Console()
    table = Table(show_header=True, header_style="green", show_lines=False)
    table.add_column("Voice ID", width=10)
    table.add_column("Voice Name", width=30)
    table.add_column("File", justify="left")

    for i, (voice_name, voice_path) in enumerate(voice_map.items()):
        table.add_row(str(i+1), voice_name, voice_path)

    console.print(table)


@click.group()
def cli():
    """
    VietTTS CLI v0.1.0

    Vietnamese Text To Speech and Voice Clone
    License: Apache 2.0 - Author: <dangvansam [email protected]>
    """
    pass

cli.add_command(start_server)
cli.add_command(synthesis)
cli.add_command(show_voice)
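The three commands above are registered on a single click group, so they can also be driven programmatically. A minimal sketch using click's test runner, assuming the package is installed and that the samples directory referenced by AUDIO_DIR is present in the working directory:

    from click.testing import CliRunner
    from VietTTS.cli import cli

    runner = CliRunner()
    # List the bundled voices (equivalent to: viettts show-voices)
    result = runner.invoke(cli, ["show-voices"])
    print(result.output)

    # Synthesize with voice ID 1 (equivalent to: viettts synthesis -t '...' -v 1 -o output.wav)
    result = runner.invoke(cli, ["synthesis", "-t", "Xin chào VietTTS", "-v", "1", "-o", "output.wav"])
    print(result.exit_code)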
VietTTS/flow/decoder.py
ADDED
@@ -0,0 +1,649 @@
import torch
import torch.nn as nn
from einops import pack, rearrange, repeat

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation

from VietTTS.transformer.transformer import BasicTransformerBlock


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Block1D(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class ConformerWrapper(ConformerBlock):
    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        return super().forward(x=hidden_states, mask=attention_mask.bool())


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]

            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c")
            mask_down = rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_down,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c")
            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_mid,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = rearrange(x, "b c t -> b t c")
            mask_up = rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_up,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask


class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape as the target, so if your text content
        is shorter or longer than the output, please resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
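For orientation, a rough shape sketch of how ConditionalDecoder is driven by the flow-matching module below. The constructor arguments are illustrative, assuming 80-band mels plus an 80-dimensional projected speaker embedding and an 80-channel prompt-mel condition (so the packed input has 4 x 80 = 320 channels), and assuming the conformer, diffusers and einops dependencies and the package's own BasicTransformerBlock are importable:

    import torch
    from VietTTS.flow.decoder import ConditionalDecoder

    decoder = ConditionalDecoder(
        in_channels=320, out_channels=80, channels=(256, 256),
        n_blocks=4, num_mid_blocks=12, num_heads=8, act_fn="gelu",
    )
    B, T = 2, 128
    x    = torch.randn(B, 80, T)   # noisy mel being refined
    mu   = torch.randn(B, 80, T)   # encoder output aligned to the mel length
    spks = torch.randn(B, 80)      # projected speaker embedding
    cond = torch.randn(B, 80, T)   # prompt-mel condition
    mask = torch.ones(B, 1, T)
    t    = torch.rand(B)           # flow-matching timestep in [0, 1]
    out  = decoder(x, mask, mu, t, spks=spks, cond=cond)   # -> (B, 80, T)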
VietTTS/flow/flow.py
ADDED
@@ -0,0 +1,158 @@
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from VietTTS.utils.mask import make_pad_mask


class MaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {
                     'in_channels': 240,
                     'out_channel': 80,
                     'spk_emb_dim': 80,
                     'n_spks': 1,
                     'cfm_params': DictConfig({
                         'sigma_min': 1e-06,
                         'solver': 'euler',
                         't_scheduler': 'cosine',
                         'training_cfg_rate': 0.2,
                         'inference_cfg_rate': 0.7,
                         'reg_loss_type': 'l1'
                     }),
                     'decoder_params': {
                         'channels': [256, 256],
                         'dropout': 0.0,
                         'attention_head_dim': 64,
                         'n_blocks': 4,
                         'num_mid_blocks': 12,
                         'num_heads': 8,
                         'act_fn': 'gelu'
                     }
                 },
                 mel_feat_conf: Dict = {
                     'n_fft': 1024,
                     'num_mels': 80,
                     'sampling_rate': 22050,
                     'hop_size': 256,
                     'win_size': 1024,
                     'fmin': 0,
                     'fmax': 8000
                 }):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat
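The inference path above converts a speech-token count into a mel-frame count with int(token_len2 / input_frame_rate * 22050 / 256). A worked example using the defaults that appear above (50 tokens per second, 22050 Hz audio, hop size 256):

    input_frame_rate = 50
    token_len2 = 150                                  # 3 seconds worth of speech tokens
    mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
    print(mel_len2)                                   # 258 frames, i.e. roughly 86.13 mel frames per second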
VietTTS/flow/flow_matching.py
ADDED
@@ -0,0 +1,268 @@
from abc import ABC

import torch
import torch.nn.functional as F

from VietTTS.flow.decoder import Decoder


class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 1e-4

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y


class CFM(BASECFM):
    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
        # Just change the architecture of the estimator here
        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)


class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
            # Classifier-Free Guidance inference introduced in VoiceBox
            if self.inference_cfg_rate > 0:
                cfg_dphi_dt = self.forward_estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    torch.zeros_like(cond)
                )
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
                           self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def forward_estimator(self, x, mask, mu, t, spks, cond):
        if isinstance(self.estimator, torch.nn.Module):
            return self.estimator.forward(x, mask, mu, t, spks, cond)
        else:
            ort_inputs = {
                'x': x.cpu().numpy(),
                'mask': mask.cpu().numpy(),
                'mu': mu.cpu().numpy(),
                't': t.cpu().numpy(),
                'spks': spks.cpu().numpy(),
                'cond': cond.cpu().numpy()
            }
            output = self.estimator.run(None, ort_inputs)[0]
            return torch.tensor(output, dtype=x.dtype, device=x.device)

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
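The training target in compute_loss is a straight-line path from noise z to the target x1: the estimator sees a point y on that path and must predict the constant velocity u. A self-contained toy sketch of just those two lines, with sigma_min as in the default config and a random tensor standing in for a real mel:

    import torch

    sigma_min = 1e-6
    x1 = torch.randn(2, 80, 100)                 # target mel-spectrogram
    z = torch.randn_like(x1)                     # noise sample x_0
    t = torch.rand(2, 1, 1)                      # one random timestep per example

    y = (1 - (1 - sigma_min) * t) * z + t * x1   # point on the path from z to x1
    u = x1 - (1 - sigma_min) * z                 # velocity the estimator is trained to predict
    # the loss is a masked MSE between estimator(y, ...) and u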
VietTTS/flow/length_regulator.py
ADDED
@@ -0,0 +1,56 @@
from typing import Tuple
import torch.nn as nn
import torch
from torch.nn import functional as F
from VietTTS.utils.mask import make_pad_mask


class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            out_channels: int = None,
            groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)

    def forward(self, x, ylens=None):
        # x in (B, T, D)
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
        out = self.model(x).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        # in inference mode, interpolate the prompt tokens and the target tokens (head/mid/tail) separately,
        # so we get a clear separation point in the mel
        # x in (B, T, D)
        if x2.shape[1] > 40:
            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
                                   mode='linear')
            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
        else:
            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
        if x1.shape[1] != 0:
            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
            x = torch.concat([x1, x2], dim=2)
        else:
            x = x2
        out = self.model(x).transpose(1, 2).contiguous()
        return out, mel_len1 + mel_len2
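At its core the regulator stretches or compresses the encoder output along time with linear interpolation before the small convolutional stack smooths it. A minimal sketch of that interpolation step alone (tensor sizes are illustrative):

    import torch
    import torch.nn.functional as F

    h = torch.randn(1, 75, 512)              # (B, T_tokens, D) encoder output
    target_len = 258                         # desired number of mel frames
    h = F.interpolate(h.transpose(1, 2), size=target_len, mode='linear')
    print(h.shape)                           # torch.Size([1, 512, 258])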
VietTTS/frontend.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import torch
import torchaudio
import whisper
import onnxruntime
import numpy as np
import torchaudio.compliance.kaldi as kaldi
from typing import Callable, List, Union
from functools import partial
from loguru import logger

from VietTTS.utils.frontend_utils import split_text, normalize_text, mel_spectrogram
from VietTTS.tokenizer.tokenizer import get_tokenizer


class TTSFrontEnd:
    def __init__(
        self,
        speech_embedding_model: str,
        speech_tokenizer_model: str,
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = get_tokenizer()
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.speech_embedding_session = onnxruntime.InferenceSession(
            speech_embedding_model,
            sess_options=option,
            providers=["CPUExecutionProvider"]
        )
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model,
            sess_options=option,
            providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]
        )
        self.spk2info = {}

    def _extract_text_token(self, text: str):
        text_token = self.tokenizer.encode(text, allowed_special='all')
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech: torch.Tensor):
        if speech.shape[1] / 16000 > 30:
            speech = speech[:, :int(16000 * 30)]
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(
            None,
            {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
             self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)}
        )[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech: torch.Tensor):
        feat = kaldi.fbank(
            waveform=speech,
            num_mel_bins=80,
            dither=0,
            sample_frequency=16000
        )
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.speech_embedding_session.run(
            None,
            {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
        )[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech: torch.Tensor):
        speech_feat = mel_spectrogram(
            y=speech,
            n_fft=1024,
            num_mels=80,
            sampling_rate=22050,
            hop_size=256,
            win_size=1024,
            fmin=0,
            fmax=8000,
            center=False
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def preprocess_text(self, text, split=True) -> Union[str, List[str]]:
        text = normalize_text(text)
        if split:
            text = list(split_text(
                text=text,
                tokenize=partial(self.tokenizer.encode, allowed_special='all'),
                token_max_n=30,
                token_min_n=10,
                merge_len=5,
                comma_split=False
            ))
        return text

    def frontend_tts(
        self,
        text: str,
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)

        text_token, text_token_len = self._extract_text_token(text)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)

        model_input = {
            'text': text_token,
            'text_len': text_token_len,
            'flow_prompt_speech_token': speech_token,
            'flow_prompt_speech_token_len': speech_token_len,
            'prompt_speech_feat': speech_feat,
            'prompt_speech_feat_len': speech_feat_len,
            'llm_embedding': embedding,
            'flow_embedding': embedding
        }
        return model_input

    def frontend_vc(
        self,
        source_speech_16k: Union[np.ndarray, torch.Tensor],
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        if isinstance(source_speech_16k, np.ndarray):
            source_speech_16k = torch.from_numpy(source_speech_16k)
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)

        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {
            'source_speech_token': source_speech_token,
            'source_speech_token_len': source_speech_token_len,
            'flow_prompt_speech_token': prompt_speech_token,
            'flow_prompt_speech_token_len': prompt_speech_token_len,
            'prompt_speech_feat': prompt_speech_feat,
            'prompt_speech_feat_len': prompt_speech_feat_len,
            'flow_embedding': embedding
        }
        return model_input
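For orientation, a sketch of how `TTSFrontEnd` could be driven for zero-shot TTS. The model and sample paths point at files in this upload, but the wiring itself is an illustrative assumption, not code from the repo.

```python
# Illustrative only: build the model inputs for zero-shot TTS from one prompt recording.
import torchaudio
from VietTTS.frontend import TTSFrontEnd

frontend = TTSFrontEnd(
    speech_embedding_model="VietTTS/models/speech_embedding.onnx",
    speech_tokenizer_model="VietTTS/models/speech_tokenizer.onnx",
)
wav, sr = torchaudio.load("VietTTS/samples/quynh.wav")
wav_16k = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav)

model_input = frontend.frontend_tts("Xin chào", wav_16k)
print({k: tuple(v.shape) for k, v in model_input.items()})
```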
VietTTS/hifigan/f0_predictor.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class ConvRNNF0Predictor(nn.Module):
    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class
        self.condnet = nn.Sequential(
            weight_norm(
                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
        )
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.condnet(x)
        x = x.transpose(1, 2)
        return torch.abs(self.classifier(x).squeeze(-1))
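A quick shape check for the predictor above, on dummy tensors (illustrative only): it maps an 80-bin mel spectrogram to one non-negative F0 value per frame.

```python
import torch
from VietTTS.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 200)   # (B, n_mels, frames)
f0 = predictor(mel)             # (B, frames), values >= 0 thanks to the final abs()
print(f0.shape)
```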
VietTTS/hifigan/generator.py
ADDED
@@ -0,0 +1,384 @@
"""HIFI-GAN"""

import typing as tp
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform

from VietTTS.transformer.activation import Snake
from VietTTS.utils.common import get_padding
from VietTTS.utils.common import init_weights


"""hifigan based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""


class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN."""
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            remove_weight_norm(self.convs2[idx])


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: [B, 1, sample_len]
        """

        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate

        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def _stft(self, x):
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        f0 = self.f0_predictor(x)
        s = self._f02source(f0)

        # use cache_source to avoid glitch
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source

        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x, s

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        self.source_module.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    @torch.inference_mode()
    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        return self.forward(x=mel, cache_source=cache_source)
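For orientation, a minimal sketch of vocoding a mel spectrogram with the generator above using its default 22.05 kHz / hop-256 settings. The weights here are random, so the output is noise; in the real pipeline `hift.pt` is loaded through `TTSModel.load()`.

```python
# Illustrative only: mel (B, 80, frames) -> waveform (B, frames * 256) plus the NSF source signal.
import torch
from VietTTS.hifigan.f0_predictor import ConvRNNF0Predictor
from VietTTS.hifigan.generator import HiFTGenerator

hift = HiFTGenerator(f0_predictor=ConvRNNF0Predictor())
mel = torch.randn(1, 80, 120)          # (B, n_mels, frames)
audio, source = hift.inference(mel=mel)
print(audio.shape)                     # torch.Size([1, 30720]) == 120 frames * 256 samples
```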
VietTTS/llm/llm.py
ADDED
@@ -0,0 +1,199 @@
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from VietTTS.utils.common import IGNORE_ID
from VietTTS.transformer.label_smoothing_loss import LabelSmoothingLoss
from VietTTS.utils.common import th_accuracy


class TransformerLM(torch.nn.Module):
    def __init__(
        self,
        text_encoder_input_size: int,
        llm_input_size: int,
        llm_output_size: int,
        text_token_size: int,
        speech_token_size: int,
        text_encoder: torch.nn.Module,
        llm: torch.nn.Module,
        sampling: Callable,
        length_normalized_loss: bool = True,
        lsm_weight: float = 0.0,
        spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

        # 4. sampling method
        self.sampling = sampling

    def encode(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)

        # 1. prepare llm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)

        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)

        # 4. sos_eos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)

        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)

        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
        self,
        weighted_scores: torch.Tensor,
        decoded_tokens: List,
        sampling: int,
        ignore_eos: bool = True,
    ):
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
        return top_ids

    @torch.inference_mode()
    def inference(
        self,
        text: torch.Tensor,
        text_len: torch.Tensor,
        prompt_text: torch.Tensor,
        prompt_text_len: torch.Tensor,
        prompt_speech_token: torch.Tensor,
        prompt_speech_token_len: torch.Tensor,
        embedding: torch.Tensor,
        sampling: int = 25,
        max_token_text_ratio: float = 20,
        min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)

        # 1. encode text
        text, text_len = self.encode(text, text_len)

        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
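The `sampling` argument of `TransformerLM` is a callable that `sampling_ids` invokes as `self.sampling(weighted_scores, decoded_tokens, sampling)`. The sketch below shows one top-k sampler matching that call signature; the project wires in its own sampler, so this version is only an assumption for illustration.

```python
# Illustrative top-k sampler for the next speech token.
import torch

def topk_sampling(weighted_scores: torch.Tensor, decoded_tokens: list, top_k: int = 25) -> torch.Tensor:
    # weighted_scores: (vocab,) log-probabilities for one decoding step.
    probs, indices = weighted_scores.softmax(dim=-1).topk(top_k)
    choice = torch.multinomial(probs / probs.sum(), num_samples=1)
    return indices[choice]

# TransformerLM(..., sampling=topk_sampling) would then call this at every decode step.
```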
VietTTS/model.py
ADDED
@@ -0,0 +1,260 @@
from loguru import logger
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from VietTTS.utils.common import fade_in_out_audio


class TTSModel:
    def __init__(
        self,
        llm: torch.nn.Module,
        flow: torch.nn.Module,
        hift: torch.nn.Module
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variable
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.hift_cache_dict = {}

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
        self.llm.to(self.device).eval()
        self.llm.half()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
        self.flow.to(self.device).eval()
        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_onnx(self, flow_decoder_estimator_model):
        import onnxruntime
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
        del self.flow.decoder.estimator
        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context:
            for i in self.llm.inference(
                text=text.to(self.device),
                text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                prompt_text=prompt_text.to(self.device),
                prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                prompt_speech_token=llm_prompt_speech_token.to(self.device),
                prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                embedding=llm_embedding.to(self.device).half()
            ):
                self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        tts_mel = self.flow.inference(
            token=token.to(self.device),
            token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
            prompt_token=prompt_token.to(self.device),
            prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
            prompt_feat=prompt_feat.to(self.device),
            prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
            embedding=embedding.to(self.device)
        )

        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)

        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
            self.hift_cache_dict[uuid] = {
                'mel': tts_mel[:, :, -self.mel_cache_len:],
                'source': tts_source[:, :, -self.source_cache_len:],
                'speech': tts_speech[:, -self.source_cache_len:]
            }
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)

        tts_speech = fade_in_out_audio(tts_speech)
        return tts_speech

    def tts(
        self,
        text: str,
        flow_embedding: torch.Tensor,
        llm_embedding: torch.Tensor = torch.zeros(0, 192),
        prompt_text: torch.Tensor = torch.zeros(1, 0, dtype=torch.int32),
        llm_prompt_speech_token: torch.Tensor = torch.zeros(1, 0, dtype=torch.int32),
        flow_prompt_speech_token: torch.Tensor = torch.zeros(1, 0, dtype=torch.int32),
        prompt_speech_feat: torch.Tensor = torch.zeros(1, 0, 80),
        stream: bool = False,
        speed: float = 1.0,
        **kwargs
    ):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None

        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()

        if stream:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.01)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(
                        token=this_tts_speech_token,
                        prompt_token=flow_prompt_speech_token,
                        prompt_feat=prompt_speech_feat,
                        embedding=flow_embedding,
                        uuid=this_uuid,
                        finalize=False
                    )
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                uuid=this_uuid,
                finalize=True
            )
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                uuid=this_uuid,
                finalize=True,
                speed=speed
            )
            yield {'tts_speech': this_tts_speech.cpu()}

        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)

    def vc(
        self,
        source_speech_token: torch.Tensor,
        flow_prompt_speech_token: torch.Tensor,
        prompt_speech_feat: torch.Tensor,
        flow_embedding: torch.Tensor,
        stream: bool = False,
        speed: float = 1.0,
        **kwargs
    ):
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None

        if stream:
            token_hop_len = self.token_min_hop_len
            while True:
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(
                        token=this_tts_speech_token,
                        prompt_token=flow_prompt_speech_token,
                        prompt_feat=prompt_speech_feat,
                        embedding=flow_embedding,
                        uuid=this_uuid,
                        finalize=False
                    )
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break

            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                uuid=this_uuid,
                finalize=True
            )
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                uuid=this_uuid,
                finalize=True,
                speed=speed
            )
            yield {'tts_speech': this_tts_speech.cpu()}

        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
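For orientation, an illustrative consumer of `TTSModel.tts()`. It assumes `model` has already been built and loaded via the repo's config-driven loader (see `VietTTS/cli.py`) and that `model_input` comes from `TTSFrontEnd.frontend_tts()` shown earlier; the helper name and output path are hypothetical.

```python
# Illustrative glue code: collect the streamed chunks into one waveform and save it.
import torch
import torchaudio

def synthesize_to_wav(model, model_input, out_path="output.wav", stream=True):
    chunks = []
    # tts() is a generator yielding {'tts_speech': tensor} per chunk; in stream mode
    # chunks arrive as soon as enough speech tokens have been decoded.
    for chunk in model.tts(stream=stream, **model_input):
        chunks.append(chunk['tts_speech'])
    speech = torch.concat(chunks, dim=1)      # (1, total_samples) at 22050 Hz
    torchaudio.save(out_path, speech, 22050)
    return speech
```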
VietTTS/models/.cache/huggingface/.gitignore
ADDED
@@ -0,0 +1 @@
*
VietTTS/models/.cache/huggingface/download/.gitattributes.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/.gitattributes.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
a6344aac8c09253b3b630fb776ae94478aa0275b
1752523446.3440115
VietTTS/models/.cache/huggingface/download/README.md.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/README.md.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
41d572572806e892430392203f5e635d461b028a
1752523446.4585946
VietTTS/models/.cache/huggingface/download/README_VN.md.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/README_VN.md.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
68fa27e83e1f31dd41485b567a7475d5abab3732
1752523446.479747
VietTTS/models/.cache/huggingface/download/config.yaml.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/config.yaml.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
062687bc82eb8d2c4ccae395158f2066f4634390
1752523445.83575
VietTTS/models/.cache/huggingface/download/flow.pt.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/flow.pt.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
1411de192039a21d53f0bf1968feb50586ce71d81ea1443f8163f4d1c46c5455
1752523560.376818
VietTTS/models/.cache/huggingface/download/hift.pt.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/hift.pt.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
91e679b6ca1eff71187ffb4f3ab0444935594cdcc20a9bd12afad111ef8d6012
1752523474.8549578
VietTTS/models/.cache/huggingface/download/llm.pt.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/llm.pt.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
c1773e5afe16a88ee82e33cf510a07717ce1346d2e74856733d72dc297a9a017
1752523690.911262
VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
1752523473.7760808
VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.lock
ADDED
File without changes
VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.metadata
ADDED
@@ -0,0 +1,3 @@
b9f49bb2ef682e162969a6919b0ed2a51a758729
56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486
1752523583.6788416
VietTTS/models/.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
VietTTS/models/README.md
ADDED
@@ -0,0 +1,213 @@
1 |
+
---
|
2 |
+
language:
|
3 |
+
- vi
|
4 |
+
- en
|
5 |
+
pipeline_tag: text-to-speech
|
6 |
+
license: apache-2.0
|
7 |
+
tags:
|
8 |
+
- tts
|
9 |
+
- text-to-speech
|
10 |
+
- vietnamese
|
11 |
+
- speech-synthesis
|
12 |
+
- speech,
|
13 |
+
- viet-tts
|
14 |
+
- viettts
|
15 |
+
---
|
16 |
+
<!-- # VietTTS: An Open-Source Vietnamese Text to Speech -->
|
17 |
+
<p align="center">
|
18 |
+
<img src="https://github.com/dangvansam/viet-tts/blob/main/assets/viet-tts-medium.png?raw=true" style="width: 200px">
|
19 |
+
<h1 align="center"style="color: white; font-weight: bold; font-family:roboto"><span style="color: white; font-weight: bold; font-family:roboto">VietTTS</span>: An Open-Source Vietnamese Text to Speech</h1>
|
20 |
+
</p>
|
21 |
+
<p align="center">
|
22 |
+
<a href="https://github.com/dangvansam/viet-tts"><img src="https://img.shields.io/github/stars/dangvansam/viet-tts?style=social"></a>
|
23 |
+
<a href="LICENSE"><img src="https://img.shields.io/github/license/dangvansam/viet-asr"></a>
|
24 |
+
<a href="https://huggingface.co/dangvansam/viet-tts/blob/main/README_VN.md"><img src="https://img.shields.io/badge/README-Tiếng Việt-blue"></a>
|
25 |
+
</p>
|
26 |
+
|
27 |
+
**VietTTS** is an open-source toolkit providing the community with a powerful Vietnamese TTS model, capable of natural voice synthesis and robust voice cloning. Designed for effective experimentation, **VietTTS** supports research and application in Vietnamese voice technologies.
|
28 |
+
|
29 |
+
## ⭐ Key Features
|
30 |
+
- **TTS**: Text-to-Speech generation with any voice via prompt audio
|
31 |
+
- **OpenAI-API-compatible**: Compatible with OpenAI's Text-to-Speech API format
|
32 |
+
|
33 |
+
## 🛠️ Installation
|
34 |
+
|
35 |
+
VietTTS can be installed via a Python installer (Linux only, with Windows and macOS support coming soon) or Docker.
|
36 |
+
|
37 |
+
### Python Installer (Python>=3.10)
|
38 |
+
```bash
|
39 |
+
git clone https://github.com/dangvansam/viet-tts.git
|
40 |
+
cd viet-tts
|
41 |
+
|
42 |
+
# (Optional) Install Python environment with conda, you could also use virtualenv
|
43 |
+
conda create --name viettts python=3.10
|
44 |
+
conda activate viettts
|
45 |
+
|
46 |
+
# Install
|
47 |
+
pip install -e . && pip cache purge
|
48 |
+
```
|
49 |
+
|
50 |
+
### Docker
|
51 |
+
|
52 |
+
1. Install [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads).
|
53 |
+
|
54 |
+
2. Run the following commands:
|
55 |
+
```bash
|
56 |
+
git clone https://github.com/dangvansam/viet-tts.git
|
57 |
+
cd viet-tts
|
58 |
+
|
59 |
+
# Build docker images
|
60 |
+
docker compose build
|
61 |
+
|
62 |
+
# Run with docker-compose - will create server at: http://localhost:8298
|
63 |
+
docker compose up -d
|
64 |
+
|
65 |
+
# Or run with docker run - will create server at: http://localhost:8298
|
66 |
+
docker run -itd --gpu=alls -p 8298:8298 -v ./pretrained-models:/app/pretrained-models -n viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298
|
67 |
+
```
|
68 |
+
|
69 |
+
## 🚀 Usage
|
70 |
+
|
71 |
+
### Built-in Voices 🤠
|
72 |
+
You can use the available voices below to synthesize speech.
|
73 |
+
<details>
|
74 |
+
<summary>Expand</summary>
|
75 |
+
|
76 |
+
| ID | Voice | Gender | Play Audio |
|
77 |
+
|-----|-----------------------|--------|--------------------------------------------------|
|
78 |
+
| 1 | nsnd-le-chuc | 👨 | <audio controls src="samples/nsnd-le-chuc.mp3"></audio> |
|
79 |
+
| 2 | speechify_10 | 👩 | <audio controls src="samples/speechify_10.wav"></audio> |
|
80 |
+
| 3 | atuan | 👨 | <audio controls src="samples/atuan.wav"></audio> |
|
81 |
+
| 4 | speechify_11 | 👩 | <audio controls src="samples/speechify_11.wav"></audio> |
|
82 |
+
| 5 | cdteam | 👨 | <audio controls src="samples/cdteam.wav"></audio> |
|
83 |
+
| 6 | speechify_12 | 👩 | <audio controls src="samples/speechify_12.wav"></audio> |
|
84 |
+
| 7 | cross_lingual_prompt | 👩 | <audio controls src="samples/cross_lingual_prompt.wav"></audio> |
|
85 |
+
| 8 | speechify_2 | 👩 | <audio controls src="samples/speechify_2.wav"></audio> |
|
86 |
+
| 9 | diep-chi | 👨 | <audio controls src="samples/diep-chi.wav"></audio> |
|
87 |
+
| 10 | speechify_3 | 👩 | <audio controls src="samples/speechify_3.wav"></audio> |
|
88 |
+
| 11 | doremon | 👨 | <audio controls src="samples/doremon.mp3"></audio> |
|
89 |
+
| 12 | speechify_4 | 👩 | <audio controls src="samples/speechify_4.wav"></audio> |
|
90 |
+
| 13 | jack-sparrow | 👨 | <audio controls src="samples/jack-sparrow.mp3"></audio> |
|
91 |
+
| 14 | speechify_5 | 👩 | <audio controls src="samples/speechify_5.wav"></audio> |
|
92 |
+
| 15 | nguyen-ngoc-ngan | 👩 | <audio controls src="samples/nguyen-ngoc-ngan.wav"></audio> |
|
93 |
+
| 16 | speechify_6 | 👩 | <audio controls src="samples/speechify_6.wav"></audio> |
|
94 |
+
| 17 | nu-nhe-nhang | 👩 | <audio controls src="samples/nu-nhe-nhang.wav"></audio> |
|
95 |
+
| 18 | speechify_7 | 👩 | <audio controls src="samples/speechify_7.wav"></audio> |
|
96 |
+
| 19 | quynh | 👩 | <audio controls src="samples/quynh.wav"></audio> |
|
97 |
+
| 20 | speechify_8 | 👩 | <audio controls src="samples/speechify_8.wav"></audio> |
|
98 |
+
| 21 | speechify_9 | 👩 | <audio controls src="samples/speechify_9.wav"></audio> |
|
99 |
+
| 22 | son-tung-mtp | 👨 | <audio controls src="samples/son-tung-mtp.wav"></audio> |
|
100 |
+
| 23 | zero_shot_prompt | 👩 | <audio controls src="samples/zero_shot_prompt.wav"></audio> |
|
101 |
+
| 24 | speechify_1 | 👩 | <audio controls src="samples/speechify_1.wav"></audio> |
|
102 |
+
|
103 |
+
<div>
|
104 |
+
</div>
|
105 |
+
</details>
|
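If the API server described below is running, the same voice list can also be fetched programmatically. A minimal sketch using the `/v1/voices` endpoint from the CURL section further down; the JSON response shape is an assumption:

```python
import requests

# Fetch the built-in voice list from a running VietTTS server.
# Endpoint taken from the CURL section below; a JSON body is assumed.
resp = requests.get("http://localhost:8298/v1/voices")
resp.raise_for_status()
print(resp.json())
```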
106 |
+
|
107 |
+
### Command Line Interface (CLI)
|
108 |
+
The VietTTS Command Line Interface (CLI) allows you to quickly generate speech directly from the terminal. Here's how to use it:
|
109 |
+
```bash
|
110 |
+
# Usage
|
111 |
+
viettts --help
|
112 |
+
|
113 |
+
# Start API Server
|
114 |
+
viettts server --host 0.0.0.0 --port 8298
|
115 |
+
|
116 |
+
# List all built-in voices
|
117 |
+
viettts show-voices
|
118 |
+
|
119 |
+
# Synthesize speech from text with built-in voices
|
120 |
+
viettts synthesis --text "Xin chào" --voice 0 --output test.wav
|
121 |
+
|
122 |
+
# Clone voice from a local audio file
|
123 |
+
viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav
|
124 |
+
```
|
125 |
+
|
126 |
+
### API Client
|
127 |
+
#### Python (OpenAI Client)
|
128 |
+
You need to set environment variables for the OpenAI Client:
|
129 |
+
```bash
|
130 |
+
# Set base_url and API key as environment variables
|
131 |
+
export OPENAI_BASE_URL=http://localhost:8298
|
132 |
+
export OPENAI_API_KEY=viet-tts # not used in the current version
|
133 |
+
```
|
134 |
+
To create speech from input text:
|
135 |
+
```python
|
136 |
+
from pathlib import Path
|
137 |
+
from openai import OpenAI
|
138 |
+
|
139 |
+
client = OpenAI()
|
140 |
+
|
141 |
+
output_file_path = Path(__file__).parent / "speech.wav"
|
142 |
+
|
143 |
+
with client.audio.speech.with_streaming_response.create(
|
144 |
+
model='tts-1',
|
145 |
+
voice='cdteam',
|
146 |
+
input='Xin chào Việt Nam.',
|
147 |
+
speed=1.0,
|
148 |
+
response_format='wav'
|
149 |
+
) as response:
|
150 |
+
response.stream_to_file(output_file_path)
|
151 |
+
```
|
152 |
+
|
153 |
+
#### CURL
|
154 |
+
```bash
|
155 |
+
# Get all built-in voices
|
156 |
+
curl --location http://0.0.0.0:8298/v1/voices
|
157 |
+
|
158 |
+
# OpenAI format (built-in voices)
|
159 |
+
curl http://localhost:8298/v1/audio/speech \
|
160 |
+
-H "Authorization: Bearer viet-tts" \
|
161 |
+
-H "Content-Type: application/json" \
|
162 |
+
-d '{
|
163 |
+
"model": "tts-1",
|
164 |
+
"input": "Xin chào Việt Nam.",
|
165 |
+
"voice": "son-tung-mtp"
|
166 |
+
}' \
|
167 |
+
--output speech.wav
|
168 |
+
|
169 |
+
# API with voice from local file
|
170 |
+
curl --location http://0.0.0.0:8298/v1/tts \
|
171 |
+
--form 'text="xin chào"' \
|
172 |
+
--form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \
|
173 |
+
--output speech.wav
|
174 |
+
```
|
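The same voice-cloning request can also be sent from Python with the `requests` package. This is a minimal sketch mirroring the `/v1/tts` curl command above; the form fields `text` and `audio_file` come from that example, and the local file names are placeholders:

```python
import requests

# Clone a voice from a local prompt audio file via the multipart /v1/tts
# endpoint, mirroring the curl example above.
with open("voice.wav", "rb") as prompt_audio:
    response = requests.post(
        "http://localhost:8298/v1/tts",
        data={"text": "xin chào"},
        files={"audio_file": prompt_audio},
    )
response.raise_for_status()

with open("speech.wav", "wb") as f:
    f.write(response.content)
```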
175 |
+
|
176 |
+
#### Node
|
177 |
+
```js
|
178 |
+
import fs from "fs";
|
179 |
+
import path from "path";
|
180 |
+
import OpenAI from "openai";
|
181 |
+
|
182 |
+
const openai = new OpenAI();
|
183 |
+
|
184 |
+
const speechFile = path.resolve("./speech.wav");
|
185 |
+
|
186 |
+
async function main() {
|
187 |
+
const response = await openai.audio.speech.create({
|
188 |
+
model: "tts-1",
|
189 |
+
voice: "1",
|
190 |
+
input: "Xin chào Việt Nam.",
|
191 |
+
});
|
192 |
+
console.log(speechFile);
|
193 |
+
const buffer = Buffer.from(await response.arrayBuffer());
|
194 |
+
await fs.promises.writeFile(speechFile, buffer);
|
195 |
+
}
|
196 |
+
main();
|
197 |
+
```
|
198 |
+
|
199 |
+
## 🙏 Acknowledgement
|
200 |
+
- 💡 Borrowed code from [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
|
201 |
+
- 🎙️ VAD model from [silero-vad](https://github.com/snakers4/silero-vad)
|
202 |
+
- 📝 Text normalization with [Vinorm](https://github.com/v-nhandt21/Vinorm)
|
203 |
+
|
204 |
+
## 📜 License
|
205 |
+
The **VietTTS** source code is released under the **Apache 2.0 License**. Pre-trained models and audio samples are licensed under the **CC BY-NC License**, because they are based on an in-the-wild dataset. We apologize for any inconvenience this may cause.
|
206 |
+
|
207 |
+
## ⚠️ Disclaimer
|
208 |
+
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
209 |
+
|
210 |
+
## 💬 Contact
|
211 |
+
- Facebook: https://fb.com/sam.rngd
|
212 |
+
- GitHub: https://github.com/dangvansam
|
213 |
+
- Email: [email protected]
|
VietTTS/models/README_VN.md
ADDED
@@ -0,0 +1,203 @@
1 |
+
<p align="center">
|
2 |
+
<img src="https://github.com/dangvansam/viet-tts/blob/main/assets/viet-tts-medium.png?raw=true" style="width: 200px">
|
3 |
+
<h1 align="center" style="color: white; font-weight: bold; font-family:roboto"><span style="color: white; font-weight: bold; font-family:roboto">VietTTS</span>: Công cụ chuyển văn bản thành giọng nói tiếng Việt mã nguồn mở</h1>
|
4 |
+
</p>
|
5 |
+
<p align="center">
|
6 |
+
<a href="https://github.com/dangvansam/viet-tts"><img src="https://img.shields.io/github/stars/dangvansam/viet-tts?style=social"></a>
|
7 |
+
<a href="LICENSE"><img src="https://img.shields.io/github/license/dangvansam/viet-asr"></a>
|
8 |
+
<a href="https://huggingface.co/dangvansam/viet-tts/blob/main/README.md"><img src="https://img.shields.io/badge/README-English-blue"></a>
|
9 |
+
</p>
|
10 |
+
|
11 |
+
**VietTTS** là một bộ công cụ mã nguồn mở cung cấp mô hình TTS tiếng Việt mạnh mẽ, cho phép tổng hợp giọng nói tự nhiên và tạo giọng nói mới. **VietTTS** hỗ trợ nghiên cứu và ứng dụng trong công nghệ giọng nói tiếng Việt.
|
12 |
+
|
13 |
+
## ⭐ Tính năng nổi bật
|
14 |
+
- **TTS**: Tổng hợp giọng nói từ văn bản với bất kỳ giọng nào qua audio mẫu
|
15 |
+
- **OpenAI-API-compatible**: Tương thích với API Text to Speech OpenAI
|
16 |
+
|
17 |
+
## 🛠️ Cài đặt
|
18 |
+
VietTTS có thể được cài đặt qua trình cài đặt Python (chỉ hỗ trợ Linux, Windows và macOS sẽ có trong tương lai) hoặc Docker.
|
19 |
+
|
20 |
+
### Trình cài đặt Python (Python>=3.10)
|
21 |
+
|
22 |
+
```bash
|
23 |
+
git clone https://github.com/dangvansam/viet-tts.git
|
24 |
+
cd viet-tts
|
25 |
+
|
26 |
+
# (Tùy chọn) Tạo môi trường Python với conda hoặc dùng virtualenv
|
27 |
+
conda create --name viettts python=3.10
|
28 |
+
conda activate viettts
|
29 |
+
|
30 |
+
# Cài đặt
|
31 |
+
pip install -e . && pip cache purge
|
32 |
+
```
|
33 |
+
|
34 |
+
### Docker
|
35 |
+
1. Cài đặt [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), và [CUDA](https://developer.nvidia.com/cuda-downloads).
|
36 |
+
|
37 |
+
2. Chạy các lệnh sau:
|
38 |
+
```bash
|
39 |
+
git clone https://github.com/dangvansam/viet-tts.git
|
40 |
+
cd viet-tts
|
41 |
+
|
42 |
+
# Xây dựng hình ảnh docker
|
43 |
+
docker compose build
|
44 |
+
|
45 |
+
# Chạy bằng docker-compose - tạo server tại: http://localhost:8298
|
46 |
+
docker compose up -d
|
47 |
+
|
48 |
+
# Chạy bằng docker run - tạo server tại: http://localhost:8298
|
49 |
+
docker run -itd --gpus all -p 8298:8298 -v ./pretrained-models:/app/pretrained-models --name viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298
|
50 |
+
```
|
51 |
+
|
52 |
+
## 🚀 Sử dụng
|
53 |
+
|
54 |
+
### Giọng nói tích hợp 🤠
|
55 |
+
Bạn có thể sử dụng các giọng nói có sẵn dưới đây để tổng hợp giọng nói.
|
56 |
+
<details>
|
57 |
+
<summary>Mở rộng</summary>
|
58 |
+
|
59 |
+
| ID | Giọng | Giới tính | Phát âm thanh |
|
60 |
+
|-----|--------------------------|-----------|-------------------------------------------------|
|
61 |
+
| 1 | nsnd-le-chuc | 👨 | <audio controls src="samples/nsnd-le-chuc.mp3"></audio> |
|
62 |
+
| 2 | speechify_10 | 👩 | <audio controls src="samples/speechify_10.wav"></audio> |
|
63 |
+
| 3 | atuan | 👨 | <audio controls src="samples/atuan.wav"></audio> |
|
64 |
+
| 4 | speechify_11 | 👩 | <audio controls src="samples/speechify_11.wav"></audio> |
|
65 |
+
| 5 | cdteam | 👨 | <audio controls src="samples/cdteam.wav"></audio> |
|
66 |
+
| 6 | speechify_12 | 👩 | <audio controls src="samples/speechify_12.wav"></audio> |
|
67 |
+
| 7 | cross_lingual_prompt | 👩 | <audio controls src="samples/cross_lingual_prompt.wav"></audio> |
|
68 |
+
| 8 | speechify_2 | 👩 | <audio controls src="samples/speechify_2.wav"></audio> |
|
69 |
+
| 9 | diep-chi | 👨 | <audio controls src="samples/diep-chi.wav"></audio> |
|
70 |
+
| 10 | speechify_3 | 👩 | <audio controls src="samples/speechify_3.wav"></audio> |
|
71 |
+
| 11 | doremon | 👨 | <audio controls src="samples/doremon.mp3"></audio> |
|
72 |
+
| 12 | speechify_4 | 👩 | <audio controls src="samples/speechify_4.wav"></audio> |
|
73 |
+
| 13 | jack-sparrow | 👨 | <audio controls src="samples/jack-sparrow.mp3"></audio> |
|
74 |
+
| 14 | speechify_5 | 👩 | <audio controls src="samples/speechify_5.wav"></audio> |
|
75 |
+
| 15 | nguyen-ngoc-ngan | 👩 | <audio controls src="samples/nguyen-ngoc-ngan.wav"></audio> |
|
76 |
+
| 16 | speechify_6 | 👩 | <audio controls src="samples/speechify_6.wav"></audio> |
|
77 |
+
| 17 | nu-nhe-nhang | 👩 | <audio controls src="samples/nu-nhe-nhang.wav"></audio> |
|
78 |
+
| 18 | speechify_7 | 👩 | <audio controls src="samples/speechify_7.wav"></audio> |
|
79 |
+
| 19 | quynh | 👩 | <audio controls src="samples/quynh.wav"></audio> |
|
80 |
+
| 20 | speechify_8 | 👩 | <audio controls src="samples/speechify_8.wav"></audio> |
|
81 |
+
| 21 | speechify_9 | 👩 | <audio controls src="samples/speechify_9.wav"></audio> |
|
82 |
+
| 22 | son-tung-mtp | 👨 | <audio controls src="samples/son-tung-mtp.wav"></audio> |
|
83 |
+
| 23 | zero_shot_prompt | 👩 | <audio controls src="samples/zero_shot_prompt.wav"></audio> |
|
84 |
+
| 24 | speechify_1 | 👩 | <audio controls src="samples/speechify_1.wav"></audio> |
|
85 |
+
|
86 |
+
<div>
|
87 |
+
|
88 |
+
</div>
|
89 |
+
|
90 |
+
</details>
|
91 |
+
|
92 |
+
### Thực thi với lệnh (CLI)
|
93 |
+
|
94 |
+
Giao diện dòng lệnh VietTTS cho phép bạn tạo giọng nói từ terminal. Cách sử dụng:
|
95 |
+
|
96 |
+
```bash
|
97 |
+
# Hướng dẫn sử dụng
|
98 |
+
viettts --help
|
99 |
+
|
100 |
+
# Khởi động API Server
|
101 |
+
viettts server --host 0.0.0.0 --port 8298
|
102 |
+
|
103 |
+
# Xem tất cả các giọng nói có sẵn
|
104 |
+
viettts show-voices
|
105 |
+
|
106 |
+
# Tổng hợp giọng nói từ văn bản với giọng có sẵn
|
107 |
+
viettts synthesis --text "Xin chào" --voice 0 --output test.wav
|
108 |
+
|
109 |
+
# Sao chép giọng từ audio file bất kì
|
110 |
+
viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav
|
111 |
+
```
|
112 |
+
|
113 |
+
### API Client
|
114 |
+
#### Python (OpenAI Client)
|
115 |
+
Thiết lập biến môi trường cho OpenAI Client:
|
116 |
+
|
117 |
+
```bash
|
118 |
+
# Thiết lập base_url và API key như biến môi trường
|
119 |
+
export OPENAI_BASE_URL=http://localhost:8298
|
120 |
+
export OPENAI_API_KEY=viet-tts # không dùng trong phiên bản hiện tại
|
121 |
+
```
|
122 |
+
|
123 |
+
Để tạo giọng nói từ văn bản đầu vào:
|
124 |
+
|
125 |
+
```python
|
126 |
+
from pathlib import Path
|
127 |
+
from openai import OpenAI
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
client = OpenAI()
|
132 |
+
output_file_path = Path(__file__).parent / "speech.wav"
|
133 |
+
|
134 |
+
with client.audio.speech.with_streaming_response.create(
|
135 |
+
model='tts-1',
|
136 |
+
voice='cdteam',
|
137 |
+
input='Xin chào Việt Nam.',
|
138 |
+
speed=1.0,
|
139 |
+
response_format='wav'
|
140 |
+
) as response:
|
141 |
+
response.stream_to_file(output_file_path)
|
142 |
+
```
|
143 |
+
|
144 |
+
#### CURL
|
145 |
+
```bash
|
146 |
+
# Lấy danh sách giọng có sẵn
|
147 |
+
curl --location http://0.0.0.0:8298/v1/voices
|
148 |
+
|
149 |
+
# OpenAI API format
|
150 |
+
curl http://localhost:8298/v1/audio/speech \
|
151 |
+
-H "Authorization: Bearer viet-tts" \
|
152 |
+
-H "Content-Type: application/json" \
|
153 |
+
-d '{
|
154 |
+
"model": "tts-1",
|
155 |
+
"input": "Xin chào Việt Nam.",
|
156 |
+
"voice": "son-tung-mtp"
|
157 |
+
}' \
|
158 |
+
--output speech.wav
|
159 |
+
|
160 |
+
# API với giọng từ file local
|
161 |
+
curl --location http://0.0.0.0:8298/v1/tts \
|
162 |
+
--form 'text="xin chào"' \
|
163 |
+
--form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \
|
164 |
+
--output speech.wav
|
165 |
+
```
|
166 |
+
|
167 |
+
#### Node
|
168 |
+
```js
|
169 |
+
import fs from "fs";
|
170 |
+
import path from "path";
|
171 |
+
import OpenAI from "openai";
|
172 |
+
|
173 |
+
const openai = new OpenAI();
|
174 |
+
const speechFile = path.resolve("./speech.wav");
|
175 |
+
|
176 |
+
async function main() {
|
177 |
+
const response = await openai.audio.speech.create({
|
178 |
+
model: "tts-1",
|
179 |
+
voice: "1",
|
180 |
+
input: "Xin chào Việt Nam.",
|
181 |
+
});
|
182 |
+
console.log(speechFile);
|
183 |
+
const buffer = Buffer.from(await response.arrayBuffer());
|
184 |
+
await fs.promises.writeFile(speechFile, buffer);
|
185 |
+
}
|
186 |
+
main();
|
187 |
+
```
|
188 |
+
|
189 |
+
## 🙏 Mã liên quan
|
190 |
+
- 💡 Sử dụng mã từ [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
|
191 |
+
- 🎙️ Mô hình VAD từ [silero-vad](https://github.com/snakers4/silero-vad)
|
192 |
+
- 📝 Chuẩn hóa văn bản với [Vinorm](https://github.com/v-nhandt21/Vinorm)
|
193 |
+
|
194 |
+
## 📜 Giấy phép
|
195 |
+
Mã nguồn của **VietTTS** được cấp phép theo **Apache 2.0 License**. Mô hình và mẫu âm thanh huấn luyện được cấp phép theo **CC BY-NC License**, dựa trên tập dữ liệu từ internet. Xin lỗi nếu điều này gây bất tiện.
|
196 |
+
|
197 |
+
## ⚠️ Tuyên bố miễn trừ trách nhiệm
|
198 |
+
Nội dung trên chỉ phục vụ mục đích học thuật và nhằm trình bày khả năng kỹ thuật. Một số ví dụ lấy từ internet. Nếu nội dung vi phạm quyền của bạn, vui lòng liên hệ để được gỡ bỏ.
|
199 |
+
|
200 |
+
## 💬 Liên hệ
|
201 |
+
- Facebook: https://fb.com/sam.rngd
|
202 |
+
- GitHub: https://github.com/dangvansam
|
203 |
+
- Email: [email protected]
|
VietTTS/models/config.yaml
ADDED
@@ -0,0 +1,129 @@
1 |
+
__set_seed1: !apply:random.seed [1986]
|
2 |
+
__set_seed2: !apply:numpy.random.seed [1986]
|
3 |
+
__set_seed3: !apply:torch.manual_seed [1986]
|
4 |
+
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
5 |
+
|
6 |
+
sample_rate: 22050
|
7 |
+
text_encoder_input_size: 512
|
8 |
+
llm_input_size: 1024
|
9 |
+
llm_output_size: 1024
|
10 |
+
spk_embed_dim: 192
|
11 |
+
|
12 |
+
llm: !new:VietTTS.llm.llm.TransformerLM
|
13 |
+
text_encoder_input_size: !ref <text_encoder_input_size>
|
14 |
+
llm_input_size: !ref <llm_input_size>
|
15 |
+
llm_output_size: !ref <llm_output_size>
|
16 |
+
text_token_size: 60515
|
17 |
+
speech_token_size: 4096
|
18 |
+
length_normalized_loss: True
|
19 |
+
lsm_weight: 0
|
20 |
+
spk_embed_dim: !ref <spk_embed_dim>
|
21 |
+
text_encoder: !new:VietTTS.transformer.encoder.ConformerEncoder
|
22 |
+
input_size: !ref <text_encoder_input_size>
|
23 |
+
output_size: 1024
|
24 |
+
attention_heads: 16
|
25 |
+
linear_units: 4096
|
26 |
+
num_blocks: 6
|
27 |
+
dropout_rate: 0.1
|
28 |
+
positional_dropout_rate: 0.1
|
29 |
+
attention_dropout_rate: 0.0
|
30 |
+
normalize_before: True
|
31 |
+
input_layer: 'linear'
|
32 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
33 |
+
selfattention_layer_type: 'rel_selfattn'
|
34 |
+
use_cnn_module: False
|
35 |
+
macaron_style: False
|
36 |
+
use_dynamic_chunk: False
|
37 |
+
use_dynamic_left_chunk: False
|
38 |
+
static_chunk_size: 1
|
39 |
+
llm: !new:VietTTS.transformer.encoder.TransformerEncoder
|
40 |
+
input_size: !ref <llm_input_size>
|
41 |
+
output_size: !ref <llm_output_size>
|
42 |
+
attention_heads: 16
|
43 |
+
linear_units: 4096
|
44 |
+
num_blocks: 14
|
45 |
+
dropout_rate: 0.1
|
46 |
+
positional_dropout_rate: 0.1
|
47 |
+
attention_dropout_rate: 0.0
|
48 |
+
input_layer: 'linear_legacy'
|
49 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
50 |
+
selfattention_layer_type: 'rel_selfattn'
|
51 |
+
static_chunk_size: 1
|
52 |
+
sampling: !name:VietTTS.utils.common.ras_sampling
|
53 |
+
top_p: 0.8
|
54 |
+
top_k: 25
|
55 |
+
win_size: 10
|
56 |
+
tau_r: 0.1
|
57 |
+
|
58 |
+
flow: !new:VietTTS.flow.flow.MaskedDiffWithXvec
|
59 |
+
input_size: 512
|
60 |
+
output_size: 80
|
61 |
+
spk_embed_dim: !ref <spk_embed_dim>
|
62 |
+
output_type: 'mel'
|
63 |
+
vocab_size: 4096
|
64 |
+
input_frame_rate: 25
|
65 |
+
only_mask_loss: True
|
66 |
+
encoder: !new:VietTTS.transformer.encoder.ConformerEncoder
|
67 |
+
output_size: 512
|
68 |
+
attention_heads: 8
|
69 |
+
linear_units: 2048
|
70 |
+
num_blocks: 6
|
71 |
+
dropout_rate: 0.1
|
72 |
+
positional_dropout_rate: 0.1
|
73 |
+
attention_dropout_rate: 0.1
|
74 |
+
normalize_before: True
|
75 |
+
input_layer: 'linear'
|
76 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
77 |
+
selfattention_layer_type: 'rel_selfattn'
|
78 |
+
input_size: 512
|
79 |
+
use_cnn_module: False
|
80 |
+
macaron_style: False
|
81 |
+
length_regulator: !new:VietTTS.flow.length_regulator.InterpolateRegulator
|
82 |
+
channels: 80
|
83 |
+
sampling_ratios: [1, 1, 1, 1]
|
84 |
+
decoder: !new:VietTTS.flow.flow_matching.ConditionalCFM
|
85 |
+
in_channels: 240
|
86 |
+
n_spks: 1
|
87 |
+
spk_emb_dim: 80
|
88 |
+
cfm_params: !new:omegaconf.DictConfig
|
89 |
+
content:
|
90 |
+
sigma_min: 1e-06
|
91 |
+
solver: 'euler'
|
92 |
+
t_scheduler: 'cosine'
|
93 |
+
training_cfg_rate: 0.2
|
94 |
+
inference_cfg_rate: 0.7
|
95 |
+
reg_loss_type: 'l1'
|
96 |
+
estimator: !new:VietTTS.flow.decoder.ConditionalDecoder
|
97 |
+
in_channels: 320
|
98 |
+
out_channels: 80
|
99 |
+
channels: [256, 256]
|
100 |
+
dropout: 0.0
|
101 |
+
attention_head_dim: 64
|
102 |
+
n_blocks: 4
|
103 |
+
num_mid_blocks: 12
|
104 |
+
num_heads: 8
|
105 |
+
act_fn: 'gelu'
|
106 |
+
|
107 |
+
hift: !new:VietTTS.hifigan.generator.HiFTGenerator
|
108 |
+
in_channels: 80
|
109 |
+
base_channels: 512
|
110 |
+
nb_harmonics: 8
|
111 |
+
sampling_rate: !ref <sample_rate>
|
112 |
+
nsf_alpha: 0.1
|
113 |
+
nsf_sigma: 0.003
|
114 |
+
nsf_voiced_threshold: 10
|
115 |
+
upsample_rates: [8, 8]
|
116 |
+
upsample_kernel_sizes: [16, 16]
|
117 |
+
istft_params:
|
118 |
+
n_fft: 16
|
119 |
+
hop_len: 4
|
120 |
+
resblock_kernel_sizes: [3, 7, 11]
|
121 |
+
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
122 |
+
source_resblock_kernel_sizes: [7, 11]
|
123 |
+
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
124 |
+
lrelu_slope: 0.1
|
125 |
+
audio_limit: 0.99
|
126 |
+
f0_predictor: !new:VietTTS.hifigan.f0_predictor.ConvRNNF0Predictor
|
127 |
+
num_class: 1
|
128 |
+
in_channels: 80
|
129 |
+
cond_channels: 512
|
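The `!new:`, `!ref`, `!apply:`, and `!name:` tags above follow the HyperPyYAML convention used in the CosyVoice codebase this project borrows from. A minimal sketch of instantiating the modules from this config, assuming the `hyperpyyaml` package is installed and that the `.pt` files below are plain state dicts:

```python
import torch
from hyperpyyaml import load_hyperpyyaml

# Build the llm / flow / hift modules from the HyperPyYAML config,
# then load the matching checkpoints shipped alongside it.
with open("VietTTS/models/config.yaml") as f:
    configs = load_hyperpyyaml(f)

llm, flow, hift = configs["llm"], configs["flow"], configs["hift"]
llm.load_state_dict(torch.load("VietTTS/models/llm.pt", map_location="cpu"))
flow.load_state_dict(torch.load("VietTTS/models/flow.pt", map_location="cpu"))
hift.load_state_dict(torch.load("VietTTS/models/hift.pt", map_location="cpu"))
```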
VietTTS/models/flow.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1411de192039a21d53f0bf1968feb50586ce71d81ea1443f8163f4d1c46c5455
|
3 |
+
size 419901370
|
VietTTS/models/hift.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91e679b6ca1eff71187ffb4f3ab0444935594cdcc20a9bd12afad111ef8d6012
|
3 |
+
size 81896716
|
VietTTS/models/llm.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c1773e5afe16a88ee82e33cf510a07717ce1346d2e74856733d72dc297a9a017
|
3 |
+
size 1260740644
|
VietTTS/models/speech_embedding.onnx
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
|
3 |
+
size 28303423
|
VietTTS/models/speech_tokenizer.onnx
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486
|
3 |
+
size 522625011
|
VietTTS/samples/cdteam.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6adf2c56a4dabcbfcc427df36f9dd268efb1153881b682071a80ad80ae4f0ac5
|
3 |
+
size 1290116
|
VietTTS/samples/cross_lingual_prompt.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:353a7715c2e4811f4045658b29d1ce67ecad5120e09de10ce890f1763aab486c
|
3 |
+
size 606404
|
VietTTS/samples/diep-chi.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:af5ae833e3d2213c09704d83535d5416744c8372368edc1005a2587c631c87ea
|
3 |
+
size 1272260
|
VietTTS/samples/doremon.mp3
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3c738f39e46db361e2e02f9064ef736c75b7dcf145873682619c123259b04762
|
3 |
+
size 761386
|
VietTTS/samples/jack-sparrow.mp3
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8729fd534af6f354c39bdb90cfa352654f876eaa0ad5759bf617797c9388878c
|
3 |
+
size 177121
|
VietTTS/samples/nguyen-ngoc-ngan.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b9ca5b01b44fdd2be7416644fbc7d463248405554fff24cf1eaaed93bd31cea
|
3 |
+
size 5351668
|
VietTTS/samples/nsnd-le-chuc.mp3
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:08d620e721295afdba2cf3d9d4e772f10cd5b416ef14c8d11284431657deeb97
|
3 |
+
size 1416881
|
VietTTS/samples/nu-nhe-nhang.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a120871b168489a33b7f3188764b0f973583bf5284bd96cd805d9e6256a7e45
|
3 |
+
size 710734
|
VietTTS/samples/quynh.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b7c5ff2187ca7a5e1371e4ab48cffdadbc38684da8ac3bcae598122ef294401f
|
3 |
+
size 2178450
|
VietTTS/samples/son-tung-mtp.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b5b22e5beb4e71b5405f7839656c4d5d71fc34f03b65a58ab27eb86a7f3dfe52
|
3 |
+
size 1473048
|