duyv committed
Commit a257816 · verified · 1 Parent(s): 898516b

Upload 86 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +23 -0
  2. VietTTS/__init__.py +0 -0
  3. VietTTS/cli.py +114 -0
  4. VietTTS/flow/decoder.py +649 -0
  5. VietTTS/flow/flow.py +158 -0
  6. VietTTS/flow/flow_matching.py +268 -0
  7. VietTTS/flow/length_regulator.py +56 -0
  8. VietTTS/frontend.py +151 -0
  9. VietTTS/hifigan/f0_predictor.py +42 -0
  10. VietTTS/hifigan/generator.py +384 -0
  11. VietTTS/llm/llm.py +199 -0
  12. VietTTS/model.py +260 -0
  13. VietTTS/models/.cache/huggingface/.gitignore +1 -0
  14. VietTTS/models/.cache/huggingface/download/.gitattributes.lock +0 -0
  15. VietTTS/models/.cache/huggingface/download/.gitattributes.metadata +3 -0
  16. VietTTS/models/.cache/huggingface/download/README.md.lock +0 -0
  17. VietTTS/models/.cache/huggingface/download/README.md.metadata +3 -0
  18. VietTTS/models/.cache/huggingface/download/README_VN.md.lock +0 -0
  19. VietTTS/models/.cache/huggingface/download/README_VN.md.metadata +3 -0
  20. VietTTS/models/.cache/huggingface/download/config.yaml.lock +0 -0
  21. VietTTS/models/.cache/huggingface/download/config.yaml.metadata +3 -0
  22. VietTTS/models/.cache/huggingface/download/flow.pt.lock +0 -0
  23. VietTTS/models/.cache/huggingface/download/flow.pt.metadata +3 -0
  24. VietTTS/models/.cache/huggingface/download/hift.pt.lock +0 -0
  25. VietTTS/models/.cache/huggingface/download/hift.pt.metadata +3 -0
  26. VietTTS/models/.cache/huggingface/download/llm.pt.lock +0 -0
  27. VietTTS/models/.cache/huggingface/download/llm.pt.metadata +3 -0
  28. VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.lock +0 -0
  29. VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.metadata +3 -0
  30. VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.lock +0 -0
  31. VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.metadata +3 -0
  32. VietTTS/models/.gitattributes +35 -0
  33. VietTTS/models/README.md +213 -0
  34. VietTTS/models/README_VN.md +203 -0
  35. VietTTS/models/config.yaml +129 -0
  36. VietTTS/models/flow.pt +3 -0
  37. VietTTS/models/hift.pt +3 -0
  38. VietTTS/models/llm.pt +3 -0
  39. VietTTS/models/speech_embedding.onnx +3 -0
  40. VietTTS/models/speech_tokenizer.onnx +3 -0
  41. VietTTS/samples/cdteam.wav +3 -0
  42. VietTTS/samples/cross_lingual_prompt.wav +3 -0
  43. VietTTS/samples/diep-chi.wav +3 -0
  44. VietTTS/samples/doremon.mp3 +3 -0
  45. VietTTS/samples/jack-sparrow.mp3 +3 -0
  46. VietTTS/samples/nguyen-ngoc-ngan.wav +3 -0
  47. VietTTS/samples/nsnd-le-chuc.mp3 +3 -0
  48. VietTTS/samples/nu-nhe-nhang.wav +3 -0
  49. VietTTS/samples/quynh.wav +3 -0
  50. VietTTS/samples/son-tung-mtp.wav +3 -0
.gitattributes CHANGED
@@ -60,3 +60,26 @@ Vinorm/vinorm/lib/libicuuc.so filter=lfs diff=lfs merge=lfs -text
60
  Vinorm/vinorm/lib/libicuuc.so.64 filter=lfs diff=lfs merge=lfs -text
61
  Vinorm/vinorm/lib/libicuuc.so.64.2 filter=lfs diff=lfs merge=lfs -text
62
  Vinorm/vinorm/main filter=lfs diff=lfs merge=lfs -text
63
+ VietTTS/samples/cdteam.wav filter=lfs diff=lfs merge=lfs -text
64
+ VietTTS/samples/cross_lingual_prompt.wav filter=lfs diff=lfs merge=lfs -text
65
+ VietTTS/samples/diep-chi.wav filter=lfs diff=lfs merge=lfs -text
66
+ VietTTS/samples/doremon.mp3 filter=lfs diff=lfs merge=lfs -text
67
+ VietTTS/samples/jack-sparrow.mp3 filter=lfs diff=lfs merge=lfs -text
68
+ VietTTS/samples/nguyen-ngoc-ngan.wav filter=lfs diff=lfs merge=lfs -text
69
+ VietTTS/samples/nsnd-le-chuc.mp3 filter=lfs diff=lfs merge=lfs -text
70
+ VietTTS/samples/nu-nhe-nhang.wav filter=lfs diff=lfs merge=lfs -text
71
+ VietTTS/samples/quynh.wav filter=lfs diff=lfs merge=lfs -text
72
+ VietTTS/samples/son-tung-mtp.wav filter=lfs diff=lfs merge=lfs -text
73
+ VietTTS/samples/speechify_1.wav filter=lfs diff=lfs merge=lfs -text
74
+ VietTTS/samples/speechify_10.wav filter=lfs diff=lfs merge=lfs -text
75
+ VietTTS/samples/speechify_11.wav filter=lfs diff=lfs merge=lfs -text
76
+ VietTTS/samples/speechify_12.wav filter=lfs diff=lfs merge=lfs -text
77
+ VietTTS/samples/speechify_2.wav filter=lfs diff=lfs merge=lfs -text
78
+ VietTTS/samples/speechify_3.wav filter=lfs diff=lfs merge=lfs -text
79
+ VietTTS/samples/speechify_4.wav filter=lfs diff=lfs merge=lfs -text
80
+ VietTTS/samples/speechify_5.wav filter=lfs diff=lfs merge=lfs -text
81
+ VietTTS/samples/speechify_6.wav filter=lfs diff=lfs merge=lfs -text
82
+ VietTTS/samples/speechify_7.wav filter=lfs diff=lfs merge=lfs -text
83
+ VietTTS/samples/speechify_8.wav filter=lfs diff=lfs merge=lfs -text
84
+ VietTTS/samples/speechify_9.wav filter=lfs diff=lfs merge=lfs -text
85
+ VietTTS/samples/zero_shot_prompt.wav filter=lfs diff=lfs merge=lfs -text
VietTTS/__init__.py ADDED
File without changes
VietTTS/cli.py ADDED
@@ -0,0 +1,114 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import click
5
+ import subprocess
6
+ from loguru import logger
7
+ from rich.table import Table
8
+ from rich.console import Console
9
+ from VietTTS.tts import TTS
10
+ from VietTTS.utils.file_utils import load_prompt_speech_from_file, load_voices
11
+
12
+
13
+ AUDIO_DIR = 'samples'
14
+ MODEL_DIR = 'pretrained-models'
15
+
16
+ @click.command('server')
17
+ @click.option('-h', '--host', type=str, default='0.0.0.0', help="The host address to bind the server to. Default is '0.0.0.0'.")
18
+ @click.option('-p', '--port', type=int, default=8298, help="The port number to bind the server to. Default is 8298.")
19
+ @click.option('-w', '--workers', type=int, default=1, help="The number of worker processes to handle requests. Default is 1.")
20
+ def start_server(host: str, port: int, workers: int):
21
+ """Start API server (OpenAI TTS API compatible).
22
+
23
+ Usage: viettts server --host 0.0.0.0 --port 8298 -w 4
24
+ """
25
+ logger.info("Starting server")
26
+ cmd = f'gunicorn viettts.server:app \
27
+ -k uvicorn.workers.UvicornWorker \
28
+ --bind {host}:{port} \
29
+ --workers {workers} \
30
+ --max-requests 1000 \
31
+ --max-requests-jitter 50 \
32
+ --timeout 300 \
33
+ --keep-alive 75 \
34
+ --graceful-timeout 60'
35
+
36
+ subprocess.call(cmd, shell=True, stdout=sys.stdout)
37
+
38
+
39
+ @click.command('synthesis')
40
+ @click.option('-t', "--text", type=str, required=True, help="The input text to synthesize into speech.")
41
+ @click.option('-v', "--voice", type=str, default='1', help="The voice ID or file path to clone the voice from. Default is '1'.")
42
+ @click.option('-s', "--speed", type=float, default=1, help="The speed multiplier for the speech. Default is 1 (normal speed).")
43
+ @click.option('-o', "--output", type=str, default='output.wav', help="The file path to save the synthesized audio. Default is 'output.wav'.")
44
+ def synthesis(text: str, voice: str, speed: float, output: str):
45
+ """Synthesis audio from text and save to file.
46
+
47
+ Usage: viettts synthesis --text 'Xin chào VietTTS' --voice nu-nhe-nhang --speed 1.2 --output test_nu-nhe-nhang.wav
48
+ """
49
+ logger.info("Starting synthesis")
50
+ st = time.perf_counter()
51
+ if not text:
52
+ logger.error('text must not be empty')
53
+ return
54
+
55
+ if speed > 2 or speed < 0.5:
56
+ logger.error('speed must be in range 0.5-2.0')
57
+ return
58
+
59
+ if not os.path.exists(voice):
60
+ voice_map = load_voices(AUDIO_DIR)
61
+ if voice.isdigit():
62
+ voice = list(voice_map.values())[int(voice) - 1]  # voice IDs printed by show-voices start at 1
63
+ else:
64
+ voice = voice_map.get(voice)
65
+
66
+ if not os.path.exists(voice):
67
+ logger.error('voice is not available. Use --voice <voice-name/voice-id/local-file> or run `viettts show-voices` to get available voices.')
68
+ return
69
+
70
+ logger.info('Loading model')
71
+ tts = TTS(model_dir=MODEL_DIR)
72
+
73
+ logger.info('Loading voice')
74
+ voice = load_prompt_speech_from_file(voice)
75
+
76
+ logger.info('Processing')
77
+ tts.tts_to_file(text, voice, speed, output)
78
+
79
+ et = time.perf_counter()
80
+ logger.success(f"Saved to: {output} [time cost={et-st:.2f}s]")
81
+
82
+
83
+ @click.command('show-voices')
84
+ def show_voice():
85
+ """Print all available voices.
86
+
87
+ Usage: viettts show-voices
88
+ """
89
+ voice_map = load_voices(AUDIO_DIR)
90
+ console = Console()
91
+ table = Table(show_header=True, header_style="green", show_lines=False)
92
+ table.add_column("Voice ID", width=10)
93
+ table.add_column("Voice Name", width=30)
94
+ table.add_column("File", justify="left")
95
+
96
+ for i, (voice_name, voice_path) in enumerate(voice_map.items()):
97
+ table.add_row(str(i+1), voice_name, voice_path)
98
+
99
+ console.print(table)
100
+
101
+
102
+ @click.group()
103
+ def cli():
104
+ """
105
+ VietTTS CLI v0.1.0
106
+
107
+ Vietnamese Text To Speech and Voice Clone
108
+ License: Apache 2.0 - Author: <dangvansam [email protected]>
109
+ """
110
+ pass
111
+
112
+ cli.add_command(start_server)
113
+ cli.add_command(synthesis)
114
+ cli.add_command(show_voice)
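The three click commands above are registered on a single cli group, so they can also be driven in-process, which is handy for smoke tests. A minimal sketch using click's CliRunner; it assumes the package is importable as VietTTS.cli and that the samples/ and pretrained-models/ directories referenced by AUDIO_DIR and MODEL_DIR exist in the working directory:

# Hedged sketch, not part of the commit: exercise the CLI group with click's test runner.
from click.testing import CliRunner
from VietTTS.cli import cli

runner = CliRunner()

# Print the table of bundled voices (reads AUDIO_DIR='samples').
result = runner.invoke(cli, ["show-voices"])
print(result.output)

# Synthesize a short utterance with voice id 1 at normal speed.
result = runner.invoke(cli, [
    "synthesis",
    "--text", "Xin chào VietTTS",
    "--voice", "1",
    "--speed", "1.0",
    "--output", "hello.wav",
])
print(result.exit_code)  # 0 on success; the audio is written to hello.wav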
VietTTS/flow/decoder.py ADDED
@@ -0,0 +1,649 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import pack, rearrange, repeat
11
+ from conformer import ConformerBlock
12
+ from diffusers.models.activations import get_activation
13
+
14
+ from VietTTS.transformer.transformer import BasicTransformerBlock
15
+
16
+
17
+ class SinusoidalPosEmb(torch.nn.Module):
18
+ def __init__(self, dim):
19
+ super().__init__()
20
+ self.dim = dim
21
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
22
+
23
+ def forward(self, x, scale=1000):
24
+ if x.ndim < 1:
25
+ x = x.unsqueeze(0)
26
+ device = x.device
27
+ half_dim = self.dim // 2
28
+ emb = math.log(10000) / (half_dim - 1)
29
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
30
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
31
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
32
+ return emb
33
+
34
+
35
+ class Block1D(torch.nn.Module):
36
+ def __init__(self, dim, dim_out, groups=8):
37
+ super().__init__()
38
+ self.block = torch.nn.Sequential(
39
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
40
+ torch.nn.GroupNorm(groups, dim_out),
41
+ nn.Mish(),
42
+ )
43
+
44
+ def forward(self, x, mask):
45
+ output = self.block(x * mask)
46
+ return output * mask
47
+
48
+
49
+ class ResnetBlock1D(torch.nn.Module):
50
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
51
+ super().__init__()
52
+ self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
53
+
54
+ self.block1 = Block1D(dim, dim_out, groups=groups)
55
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
56
+
57
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
58
+
59
+ def forward(self, x, mask, time_emb):
60
+ h = self.block1(x, mask)
61
+ h += self.mlp(time_emb).unsqueeze(-1)
62
+ h = self.block2(h, mask)
63
+ output = h + self.res_conv(x * mask)
64
+ return output
65
+
66
+
67
+ class Downsample1D(nn.Module):
68
+ def __init__(self, dim):
69
+ super().__init__()
70
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
71
+
72
+ def forward(self, x):
73
+ return self.conv(x)
74
+
75
+
76
+ class TimestepEmbedding(nn.Module):
77
+ def __init__(
78
+ self,
79
+ in_channels: int,
80
+ time_embed_dim: int,
81
+ act_fn: str = "silu",
82
+ out_dim: int = None,
83
+ post_act_fn: Optional[str] = None,
84
+ cond_proj_dim=None,
85
+ ):
86
+ super().__init__()
87
+
88
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
89
+
90
+ if cond_proj_dim is not None:
91
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
92
+ else:
93
+ self.cond_proj = None
94
+
95
+ self.act = get_activation(act_fn)
96
+
97
+ if out_dim is not None:
98
+ time_embed_dim_out = out_dim
99
+ else:
100
+ time_embed_dim_out = time_embed_dim
101
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
102
+
103
+ if post_act_fn is None:
104
+ self.post_act = None
105
+ else:
106
+ self.post_act = get_activation(post_act_fn)
107
+
108
+ def forward(self, sample, condition=None):
109
+ if condition is not None:
110
+ sample = sample + self.cond_proj(condition)
111
+ sample = self.linear_1(sample)
112
+
113
+ if self.act is not None:
114
+ sample = self.act(sample)
115
+
116
+ sample = self.linear_2(sample)
117
+
118
+ if self.post_act is not None:
119
+ sample = self.post_act(sample)
120
+ return sample
121
+
122
+
123
+ class Upsample1D(nn.Module):
124
+ """A 1D upsampling layer with an optional convolution.
125
+
126
+ Parameters:
127
+ channels (`int`):
128
+ number of channels in the inputs and outputs.
129
+ use_conv (`bool`, default `False`):
130
+ option to use a convolution.
131
+ use_conv_transpose (`bool`, default `False`):
132
+ option to use a convolution transpose.
133
+ out_channels (`int`, optional):
134
+ number of output channels. Defaults to `channels`.
135
+ """
136
+
137
+ def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
138
+ super().__init__()
139
+ self.channels = channels
140
+ self.out_channels = out_channels or channels
141
+ self.use_conv = use_conv
142
+ self.use_conv_transpose = use_conv_transpose
143
+ self.name = name
144
+
145
+ self.conv = None
146
+ if use_conv_transpose:
147
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
148
+ elif use_conv:
149
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
150
+
151
+ def forward(self, inputs):
152
+ assert inputs.shape[1] == self.channels
153
+ if self.use_conv_transpose:
154
+ return self.conv(inputs)
155
+
156
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
157
+
158
+ if self.use_conv:
159
+ outputs = self.conv(outputs)
160
+
161
+ return outputs
162
+
163
+
164
+ class ConformerWrapper(ConformerBlock):
165
+ def __init__( # pylint: disable=useless-super-delegation
166
+ self,
167
+ *,
168
+ dim,
169
+ dim_head=64,
170
+ heads=8,
171
+ ff_mult=4,
172
+ conv_expansion_factor=2,
173
+ conv_kernel_size=31,
174
+ attn_dropout=0,
175
+ ff_dropout=0,
176
+ conv_dropout=0,
177
+ conv_causal=False,
178
+ ):
179
+ super().__init__(
180
+ dim=dim,
181
+ dim_head=dim_head,
182
+ heads=heads,
183
+ ff_mult=ff_mult,
184
+ conv_expansion_factor=conv_expansion_factor,
185
+ conv_kernel_size=conv_kernel_size,
186
+ attn_dropout=attn_dropout,
187
+ ff_dropout=ff_dropout,
188
+ conv_dropout=conv_dropout,
189
+ conv_causal=conv_causal,
190
+ )
191
+
192
+ def forward(
193
+ self,
194
+ hidden_states,
195
+ attention_mask,
196
+ encoder_hidden_states=None,
197
+ encoder_attention_mask=None,
198
+ timestep=None,
199
+ ):
200
+ return super().forward(x=hidden_states, mask=attention_mask.bool())
201
+
202
+
203
+ class Decoder(nn.Module):
204
+ def __init__(
205
+ self,
206
+ in_channels,
207
+ out_channels,
208
+ channels=(256, 256),
209
+ dropout=0.05,
210
+ attention_head_dim=64,
211
+ n_blocks=1,
212
+ num_mid_blocks=2,
213
+ num_heads=4,
214
+ act_fn="snake",
215
+ down_block_type="transformer",
216
+ mid_block_type="transformer",
217
+ up_block_type="transformer",
218
+ ):
219
+ super().__init__()
220
+ channels = tuple(channels)
221
+ self.in_channels = in_channels
222
+ self.out_channels = out_channels
223
+
224
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
225
+ time_embed_dim = channels[0] * 4
226
+ self.time_mlp = TimestepEmbedding(
227
+ in_channels=in_channels,
228
+ time_embed_dim=time_embed_dim,
229
+ act_fn="silu",
230
+ )
231
+
232
+ self.down_blocks = nn.ModuleList([])
233
+ self.mid_blocks = nn.ModuleList([])
234
+ self.up_blocks = nn.ModuleList([])
235
+
236
+ output_channel = in_channels
237
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
238
+ input_channel = output_channel
239
+ output_channel = channels[i]
240
+ is_last = i == len(channels) - 1
241
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
242
+ transformer_blocks = nn.ModuleList(
243
+ [
244
+ self.get_block(
245
+ down_block_type,
246
+ output_channel,
247
+ attention_head_dim,
248
+ num_heads,
249
+ dropout,
250
+ act_fn,
251
+ )
252
+ for _ in range(n_blocks)
253
+ ]
254
+ )
255
+ downsample = (
256
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
257
+ )
258
+
259
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
260
+
261
+ for i in range(num_mid_blocks):
262
+ input_channel = channels[-1]
263
+ out_channels = channels[-1]
264
+
265
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
266
+
267
+ transformer_blocks = nn.ModuleList(
268
+ [
269
+ self.get_block(
270
+ mid_block_type,
271
+ output_channel,
272
+ attention_head_dim,
273
+ num_heads,
274
+ dropout,
275
+ act_fn,
276
+ )
277
+ for _ in range(n_blocks)
278
+ ]
279
+ )
280
+
281
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
282
+
283
+ channels = channels[::-1] + (channels[0],)
284
+ for i in range(len(channels) - 1):
285
+ input_channel = channels[i]
286
+ output_channel = channels[i + 1]
287
+ is_last = i == len(channels) - 2
288
+
289
+ resnet = ResnetBlock1D(
290
+ dim=2 * input_channel,
291
+ dim_out=output_channel,
292
+ time_emb_dim=time_embed_dim,
293
+ )
294
+ transformer_blocks = nn.ModuleList(
295
+ [
296
+ self.get_block(
297
+ up_block_type,
298
+ output_channel,
299
+ attention_head_dim,
300
+ num_heads,
301
+ dropout,
302
+ act_fn,
303
+ )
304
+ for _ in range(n_blocks)
305
+ ]
306
+ )
307
+ upsample = (
308
+ Upsample1D(output_channel, use_conv_transpose=True)
309
+ if not is_last
310
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
311
+ )
312
+
313
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
314
+
315
+ self.final_block = Block1D(channels[-1], channels[-1])
316
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
317
+
318
+ self.initialize_weights()
319
+ # nn.init.normal_(self.final_proj.weight)
320
+
321
+ @staticmethod
322
+ def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
323
+ if block_type == "conformer":
324
+ block = ConformerWrapper(
325
+ dim=dim,
326
+ dim_head=attention_head_dim,
327
+ heads=num_heads,
328
+ ff_mult=1,
329
+ conv_expansion_factor=2,
330
+ ff_dropout=dropout,
331
+ attn_dropout=dropout,
332
+ conv_dropout=dropout,
333
+ conv_kernel_size=31,
334
+ )
335
+ elif block_type == "transformer":
336
+ block = BasicTransformerBlock(
337
+ dim=dim,
338
+ num_attention_heads=num_heads,
339
+ attention_head_dim=attention_head_dim,
340
+ dropout=dropout,
341
+ activation_fn=act_fn,
342
+ )
343
+ else:
344
+ raise ValueError(f"Unknown block type {block_type}")
345
+
346
+ return block
347
+
348
+ def initialize_weights(self):
349
+ for m in self.modules():
350
+ if isinstance(m, nn.Conv1d):
351
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
352
+
353
+ if m.bias is not None:
354
+ nn.init.constant_(m.bias, 0)
355
+
356
+ elif isinstance(m, nn.GroupNorm):
357
+ nn.init.constant_(m.weight, 1)
358
+ nn.init.constant_(m.bias, 0)
359
+
360
+ elif isinstance(m, nn.Linear):
361
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
362
+
363
+ if m.bias is not None:
364
+ nn.init.constant_(m.bias, 0)
365
+
366
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
367
+ """Forward pass of the UNet1DConditional model.
368
+
369
+ Args:
370
+ x (torch.Tensor): shape (batch_size, in_channels, time)
371
+ mask (torch.Tensor): shape (batch_size, 1, time)
372
+ t (torch.Tensor): shape (batch_size,)
373
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
374
+ cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
375
+
376
+ Returns:
377
+ torch.Tensor: output tensor of shape (batch_size, out_channels, time), masked by mask.
382
+ """
383
+
384
+ t = self.time_embeddings(t)
385
+ t = self.time_mlp(t)
386
+
387
+ x = pack([x, mu], "b * t")[0]
388
+
389
+ if spks is not None:
390
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
391
+ x = pack([x, spks], "b * t")[0]
392
+
393
+ hiddens = []
394
+ masks = [mask]
395
+ for resnet, transformer_blocks, downsample in self.down_blocks:
396
+ mask_down = masks[-1]
397
+ x = resnet(x, mask_down, t)
398
+ x = rearrange(x, "b c t -> b t c")
399
+ mask_down = rearrange(mask_down, "b 1 t -> b t")
400
+ for transformer_block in transformer_blocks:
401
+ x = transformer_block(
402
+ hidden_states=x,
403
+ attention_mask=mask_down,
404
+ timestep=t,
405
+ )
406
+ x = rearrange(x, "b t c -> b c t")
407
+ mask_down = rearrange(mask_down, "b t -> b 1 t")
408
+ hiddens.append(x) # Save hidden states for skip connections
409
+ x = downsample(x * mask_down)
410
+ masks.append(mask_down[:, :, ::2])
411
+
412
+ masks = masks[:-1]
413
+ mask_mid = masks[-1]
414
+
415
+ for resnet, transformer_blocks in self.mid_blocks:
416
+ x = resnet(x, mask_mid, t)
417
+ x = rearrange(x, "b c t -> b t c")
418
+ mask_mid = rearrange(mask_mid, "b 1 t -> b t")
419
+ for transformer_block in transformer_blocks:
420
+ x = transformer_block(
421
+ hidden_states=x,
422
+ attention_mask=mask_mid,
423
+ timestep=t,
424
+ )
425
+ x = rearrange(x, "b t c -> b c t")
426
+ mask_mid = rearrange(mask_mid, "b t -> b 1 t")
427
+
428
+ for resnet, transformer_blocks, upsample in self.up_blocks:
429
+ mask_up = masks.pop()
430
+ x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
431
+ x = rearrange(x, "b c t -> b t c")
432
+ mask_up = rearrange(mask_up, "b 1 t -> b t")
433
+ for transformer_block in transformer_blocks:
434
+ x = transformer_block(
435
+ hidden_states=x,
436
+ attention_mask=mask_up,
437
+ timestep=t,
438
+ )
439
+ x = rearrange(x, "b t c -> b c t")
440
+ mask_up = rearrange(mask_up, "b t -> b 1 t")
441
+ x = upsample(x * mask_up)
442
+
443
+ x = self.final_block(x, mask_up)
444
+ output = self.final_proj(x * mask_up)
445
+
446
+ return output * mask
447
+
448
+
449
+ class ConditionalDecoder(nn.Module):
450
+ def __init__(
451
+ self,
452
+ in_channels,
453
+ out_channels,
454
+ channels=(256, 256),
455
+ dropout=0.05,
456
+ attention_head_dim=64,
457
+ n_blocks=1,
458
+ num_mid_blocks=2,
459
+ num_heads=4,
460
+ act_fn="snake",
461
+ ):
462
+ """
463
+ This decoder requires an input with the same shape as the target. So, if your text content
464
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
465
+ """
466
+ super().__init__()
467
+ channels = tuple(channels)
468
+ self.in_channels = in_channels
469
+ self.out_channels = out_channels
470
+
471
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
472
+ time_embed_dim = channels[0] * 4
473
+ self.time_mlp = TimestepEmbedding(
474
+ in_channels=in_channels,
475
+ time_embed_dim=time_embed_dim,
476
+ act_fn="silu",
477
+ )
478
+ self.down_blocks = nn.ModuleList([])
479
+ self.mid_blocks = nn.ModuleList([])
480
+ self.up_blocks = nn.ModuleList([])
481
+
482
+ output_channel = in_channels
483
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
484
+ input_channel = output_channel
485
+ output_channel = channels[i]
486
+ is_last = i == len(channels) - 1
487
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
488
+ transformer_blocks = nn.ModuleList(
489
+ [
490
+ BasicTransformerBlock(
491
+ dim=output_channel,
492
+ num_attention_heads=num_heads,
493
+ attention_head_dim=attention_head_dim,
494
+ dropout=dropout,
495
+ activation_fn=act_fn,
496
+ )
497
+ for _ in range(n_blocks)
498
+ ]
499
+ )
500
+ downsample = (
501
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
502
+ )
503
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
504
+
505
+ for _ in range(num_mid_blocks):
506
+ input_channel = channels[-1]
507
+ out_channels = channels[-1]
508
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
509
+
510
+ transformer_blocks = nn.ModuleList(
511
+ [
512
+ BasicTransformerBlock(
513
+ dim=output_channel,
514
+ num_attention_heads=num_heads,
515
+ attention_head_dim=attention_head_dim,
516
+ dropout=dropout,
517
+ activation_fn=act_fn,
518
+ )
519
+ for _ in range(n_blocks)
520
+ ]
521
+ )
522
+
523
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
524
+
525
+ channels = channels[::-1] + (channels[0],)
526
+ for i in range(len(channels) - 1):
527
+ input_channel = channels[i] * 2
528
+ output_channel = channels[i + 1]
529
+ is_last = i == len(channels) - 2
530
+ resnet = ResnetBlock1D(
531
+ dim=input_channel,
532
+ dim_out=output_channel,
533
+ time_emb_dim=time_embed_dim,
534
+ )
535
+ transformer_blocks = nn.ModuleList(
536
+ [
537
+ BasicTransformerBlock(
538
+ dim=output_channel,
539
+ num_attention_heads=num_heads,
540
+ attention_head_dim=attention_head_dim,
541
+ dropout=dropout,
542
+ activation_fn=act_fn,
543
+ )
544
+ for _ in range(n_blocks)
545
+ ]
546
+ )
547
+ upsample = (
548
+ Upsample1D(output_channel, use_conv_transpose=True)
549
+ if not is_last
550
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
551
+ )
552
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
553
+ self.final_block = Block1D(channels[-1], channels[-1])
554
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
555
+ self.initialize_weights()
556
+
557
+ def initialize_weights(self):
558
+ for m in self.modules():
559
+ if isinstance(m, nn.Conv1d):
560
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
561
+ if m.bias is not None:
562
+ nn.init.constant_(m.bias, 0)
563
+ elif isinstance(m, nn.GroupNorm):
564
+ nn.init.constant_(m.weight, 1)
565
+ nn.init.constant_(m.bias, 0)
566
+ elif isinstance(m, nn.Linear):
567
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
568
+ if m.bias is not None:
569
+ nn.init.constant_(m.bias, 0)
570
+
571
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
572
+ """Forward pass of the UNet1DConditional model.
573
+
574
+ Args:
575
+ x (torch.Tensor): shape (batch_size, in_channels, time)
576
+ mask (torch.Tensor): shape (batch_size, 1, time)
577
+ t (torch.Tensor): shape (batch_size,)
578
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
579
+ cond (torch.Tensor, optional): conditioning mel frames packed onto the input, shape (batch_size, out_channels, time). Defaults to None.
580
+
581
+ Returns:
582
+ torch.Tensor: output tensor of shape (batch_size, out_channels, time), masked by mask.
587
+ """
588
+
589
+ t = self.time_embeddings(t).to(t.dtype)
590
+ t = self.time_mlp(t)
591
+
592
+ x = pack([x, mu], "b * t")[0]
593
+
594
+ if spks is not None:
595
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
596
+ x = pack([x, spks], "b * t")[0]
597
+ if cond is not None:
598
+ x = pack([x, cond], "b * t")[0]
599
+
600
+ hiddens = []
601
+ masks = [mask]
602
+ for resnet, transformer_blocks, downsample in self.down_blocks:
603
+ mask_down = masks[-1]
604
+ x = resnet(x, mask_down, t)
605
+ x = rearrange(x, "b c t -> b t c").contiguous()
606
+ attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
607
+ for transformer_block in transformer_blocks:
608
+ x = transformer_block(
609
+ hidden_states=x,
610
+ attention_mask=attn_mask,
611
+ timestep=t,
612
+ )
613
+ x = rearrange(x, "b t c -> b c t").contiguous()
614
+ hiddens.append(x) # Save hidden states for skip connections
615
+ x = downsample(x * mask_down)
616
+ masks.append(mask_down[:, :, ::2])
617
+ masks = masks[:-1]
618
+ mask_mid = masks[-1]
619
+
620
+ for resnet, transformer_blocks in self.mid_blocks:
621
+ x = resnet(x, mask_mid, t)
622
+ x = rearrange(x, "b c t -> b t c").contiguous()
623
+ attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
624
+ for transformer_block in transformer_blocks:
625
+ x = transformer_block(
626
+ hidden_states=x,
627
+ attention_mask=attn_mask,
628
+ timestep=t,
629
+ )
630
+ x = rearrange(x, "b t c -> b c t").contiguous()
631
+
632
+ for resnet, transformer_blocks, upsample in self.up_blocks:
633
+ mask_up = masks.pop()
634
+ skip = hiddens.pop()
635
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
636
+ x = resnet(x, mask_up, t)
637
+ x = rearrange(x, "b c t -> b t c").contiguous()
638
+ attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
639
+ for transformer_block in transformer_blocks:
640
+ x = transformer_block(
641
+ hidden_states=x,
642
+ attention_mask=attn_mask,
643
+ timestep=t,
644
+ )
645
+ x = rearrange(x, "b t c -> b c t").contiguous()
646
+ x = upsample(x * mask_up)
647
+ x = self.final_block(x, mask_up)
648
+ output = self.final_proj(x * mask_up)
649
+ return output * mask
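Both decoders condition every ResnetBlock1D on a time embedding built from SinusoidalPosEmb followed by TimestepEmbedding. A minimal shape check of that path, assuming only that torch and diffusers (for get_activation) are installed; the dimensions mirror the defaults used elsewhere in this commit (in_channels=240, channels=[256, 256], so time_embed_dim=1024):

# Hedged shape check, not part of the commit.
import torch
from VietTTS.flow.decoder import SinusoidalPosEmb, TimestepEmbedding

dim = 240                                   # matches decoder_conf['in_channels'] in flow.py
pos_emb = SinusoidalPosEmb(dim)
time_mlp = TimestepEmbedding(in_channels=dim, time_embed_dim=4 * 256, act_fn="silu")

t = torch.rand(2)                           # one flow-matching timestep per batch element
emb = pos_emb(t)                            # -> (2, 240) sinusoidal features
cond = time_mlp(emb)                        # -> (2, 1024), added inside each ResnetBlock1D via its mlp
print(emb.shape, cond.shape)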
VietTTS/flow/flow.py ADDED
@@ -0,0 +1,158 @@
1
+ import logging
2
+ import random
3
+ from typing import Dict, Optional
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+ from omegaconf import DictConfig
8
+ from VietTTS.utils.mask import make_pad_mask
9
+
10
+
11
+ class MaskedDiffWithXvec(torch.nn.Module):
12
+ def __init__(self,
13
+ input_size: int = 512,
14
+ output_size: int = 80,
15
+ spk_embed_dim: int = 192,
16
+ output_type: str = "mel",
17
+ vocab_size: int = 4096,
18
+ input_frame_rate: int = 50,
19
+ only_mask_loss: bool = True,
20
+ encoder: torch.nn.Module = None,
21
+ length_regulator: torch.nn.Module = None,
22
+ decoder: torch.nn.Module = None,
23
+ decoder_conf: Dict = {
24
+ 'in_channels': 240,
25
+ 'out_channel': 80,
26
+ 'spk_emb_dim': 80,
27
+ 'n_spks': 1,
28
+ 'cfm_params': DictConfig({
29
+ 'sigma_min': 1e-06,
30
+ 'solver': 'euler',
31
+ 't_scheduler': 'cosine',
32
+ 'training_cfg_rate': 0.2,
33
+ 'inference_cfg_rate': 0.7,
34
+ 'reg_loss_type': 'l1'
35
+ }),
36
+ 'decoder_params': {
37
+ 'channels': [256, 256],
38
+ 'dropout': 0.0,
39
+ 'attention_head_dim': 64,
40
+ 'n_blocks': 4,
41
+ 'num_mid_blocks': 12,
42
+ 'num_heads': 8,
43
+ 'act_fn': 'gelu'
44
+ }
45
+ },
46
+ mel_feat_conf: Dict = {
47
+ 'n_fft': 1024,
48
+ 'num_mels': 80,
49
+ 'sampling_rate': 22050,
50
+ 'hop_size': 256,
51
+ 'win_size': 1024,
52
+ 'fmin': 0,
53
+ 'fmax': 8000
54
+ }
55
+ ):
56
+ super().__init__()
57
+ self.input_size = input_size
58
+ self.output_size = output_size
59
+ self.decoder_conf = decoder_conf
60
+ self.mel_feat_conf = mel_feat_conf
61
+ self.vocab_size = vocab_size
62
+ self.output_type = output_type
63
+ self.input_frame_rate = input_frame_rate
64
+ logging.info(f"input frame rate={self.input_frame_rate}")
65
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
66
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
67
+ self.encoder = encoder
68
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
69
+ self.decoder = decoder
70
+ self.length_regulator = length_regulator
71
+ self.only_mask_loss = only_mask_loss
72
+
73
+ def forward(
74
+ self,
75
+ batch: dict,
76
+ device: torch.device,
77
+ ) -> Dict[str, Optional[torch.Tensor]]:
78
+ token = batch['speech_token'].to(device)
79
+ token_len = batch['speech_token_len'].to(device)
80
+ feat = batch['speech_feat'].to(device)
81
+ feat_len = batch['speech_feat_len'].to(device)
82
+ embedding = batch['embedding'].to(device)
83
+
84
+ # xvec projection
85
+ embedding = F.normalize(embedding, dim=1)
86
+ embedding = self.spk_embed_affine_layer(embedding)
87
+
88
+ # concat text and prompt_text
89
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
90
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
91
+
92
+ # text encode
93
+ h, h_lengths = self.encoder(token, token_len)
94
+ h = self.encoder_proj(h)
95
+ h, h_lengths = self.length_regulator(h, feat_len)
96
+
97
+ # get conditions
98
+ conds = torch.zeros(feat.shape, device=token.device)
99
+ for i, j in enumerate(feat_len):
100
+ if random.random() < 0.5:
101
+ continue
102
+ index = random.randint(0, int(0.3 * j))
103
+ conds[i, :index] = feat[i, :index]
104
+ conds = conds.transpose(1, 2)
105
+
106
+ mask = (~make_pad_mask(feat_len)).to(h)
107
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
108
+ loss, _ = self.decoder.compute_loss(
109
+ feat.transpose(1, 2).contiguous(),
110
+ mask.unsqueeze(1),
111
+ h.transpose(1, 2).contiguous(),
112
+ embedding,
113
+ cond=conds
114
+ )
115
+ return {'loss': loss}
116
+
117
+ @torch.inference_mode()
118
+ def inference(self,
119
+ token,
120
+ token_len,
121
+ prompt_token,
122
+ prompt_token_len,
123
+ prompt_feat,
124
+ prompt_feat_len,
125
+ embedding):
126
+ assert token.shape[0] == 1
127
+ # xvec projection
128
+ embedding = F.normalize(embedding, dim=1)
129
+ embedding = self.spk_embed_affine_layer(embedding)
130
+
131
+ # concat text and prompt_text
132
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
133
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
134
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
135
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
136
+
137
+ # text encode
138
+ h, h_lengths = self.encoder(token, token_len)
139
+ h = self.encoder_proj(h)
140
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
141
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
142
+
143
+ # get conditions
144
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
145
+ conds[:, :mel_len1] = prompt_feat
146
+ conds = conds.transpose(1, 2)
147
+
148
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
149
+ feat = self.decoder(
150
+ mu=h.transpose(1, 2).contiguous(),
151
+ mask=mask.unsqueeze(1),
152
+ spks=embedding,
153
+ cond=conds,
154
+ n_timesteps=10
155
+ )
156
+ feat = feat[:, :, mel_len1:]
157
+ assert feat.shape[2] == mel_len2
158
+ return feat
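The length bookkeeping in MaskedDiffWithXvec.inference converts a token count into a mel frame count with int(token_len2 / self.input_frame_rate * 22050 / 256): speech tokens arrive at input_frame_rate (50 per second by default), while mel frames are produced at 22050 / 256, roughly 86.1 per second. A small worked example of that arithmetic:

# Worked example of the token-to-mel length mapping (values are illustrative).
token_len2 = 100                  # 100 speech tokens at 50 tokens/s -> 2.0 s of audio
input_frame_rate = 50
mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
print(mel_len2)                   # 172 mel frames (2.0 s at hop size 256 and 22.05 kHz)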
VietTTS/flow/flow_matching.py ADDED
@@ -0,0 +1,268 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from VietTTS.flow.decoder import Decoder
7
+
8
+
9
+ class BASECFM(torch.nn.Module, ABC):
10
+ def __init__(
11
+ self,
12
+ n_feats,
13
+ cfm_params,
14
+ n_spks=1,
15
+ spk_emb_dim=128,
16
+ ):
17
+ super().__init__()
18
+ self.n_feats = n_feats
19
+ self.n_spks = n_spks
20
+ self.spk_emb_dim = spk_emb_dim
21
+ self.solver = cfm_params.solver
22
+ if hasattr(cfm_params, "sigma_min"):
23
+ self.sigma_min = cfm_params.sigma_min
24
+ else:
25
+ self.sigma_min = 1e-4
26
+
27
+ self.estimator = None
28
+
29
+ @torch.inference_mode()
30
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
31
+ """Forward diffusion
32
+
33
+ Args:
34
+ mu (torch.Tensor): output of encoder
35
+ shape: (batch_size, n_feats, mel_timesteps)
36
+ mask (torch.Tensor): output_mask
37
+ shape: (batch_size, 1, mel_timesteps)
38
+ n_timesteps (int): number of diffusion steps
39
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
40
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
41
+ shape: (batch_size, spk_emb_dim)
42
+ cond: Not used but kept for future purposes
43
+
44
+ Returns:
45
+ sample: generated mel-spectrogram
46
+ shape: (batch_size, n_feats, mel_timesteps)
47
+ """
48
+ z = torch.randn_like(mu) * temperature
49
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
50
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
51
+
52
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
53
+ """
54
+ Fixed-step Euler solver for ODEs.
55
+ Args:
56
+ x (torch.Tensor): random noise
57
+ t_span (torch.Tensor): n_timesteps interpolated
58
+ shape: (n_timesteps + 1,)
59
+ mu (torch.Tensor): output of encoder
60
+ shape: (batch_size, n_feats, mel_timesteps)
61
+ mask (torch.Tensor): output_mask
62
+ shape: (batch_size, 1, mel_timesteps)
63
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
64
+ shape: (batch_size, spk_emb_dim)
65
+ cond: Not used but kept for future purposes
66
+ """
67
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
68
+
69
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
70
+ # Or in future might add like a return_all_steps flag
71
+ sol = []
72
+
73
+ for step in range(1, len(t_span)):
74
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
75
+
76
+ x = x + dt * dphi_dt
77
+ t = t + dt
78
+ sol.append(x)
79
+ if step < len(t_span) - 1:
80
+ dt = t_span[step + 1] - t
81
+
82
+ return sol[-1]
83
+
84
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
85
+ """Computes diffusion loss
86
+
87
+ Args:
88
+ x1 (torch.Tensor): Target
89
+ shape: (batch_size, n_feats, mel_timesteps)
90
+ mask (torch.Tensor): target mask
91
+ shape: (batch_size, 1, mel_timesteps)
92
+ mu (torch.Tensor): output of encoder
93
+ shape: (batch_size, n_feats, mel_timesteps)
94
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
95
+ shape: (batch_size, spk_emb_dim)
96
+
97
+ Returns:
98
+ loss: conditional flow matching loss
99
+ y: conditional flow
100
+ shape: (batch_size, n_feats, mel_timesteps)
101
+ """
102
+ b, _, t = mu.shape
103
+
104
+ # random timestep
105
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
106
+ # sample noise p(x_0)
107
+ z = torch.randn_like(x1)
108
+
109
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
110
+ u = x1 - (1 - self.sigma_min) * z
111
+
112
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
113
+ torch.sum(mask) * u.shape[1]
114
+ )
115
+ return loss, y
116
+
117
+
118
+ class CFM(BASECFM):
119
+ def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
120
+ super().__init__(
121
+ n_feats=in_channels,
122
+ cfm_params=cfm_params,
123
+ n_spks=n_spks,
124
+ spk_emb_dim=spk_emb_dim,
125
+ )
126
+
127
+ in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
128
+ # Just change the architecture of the estimator here
129
+ self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
130
+
131
+
132
+ class ConditionalCFM(BASECFM):
133
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
134
+ super().__init__(
135
+ n_feats=in_channels,
136
+ cfm_params=cfm_params,
137
+ n_spks=n_spks,
138
+ spk_emb_dim=spk_emb_dim,
139
+ )
140
+ self.t_scheduler = cfm_params.t_scheduler
141
+ self.training_cfg_rate = cfm_params.training_cfg_rate
142
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
143
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
144
+ # Just change the architecture of the estimator here
145
+ self.estimator = estimator
146
+
147
+ @torch.inference_mode()
148
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
149
+ """Forward diffusion
150
+
151
+ Args:
152
+ mu (torch.Tensor): output of encoder
153
+ shape: (batch_size, n_feats, mel_timesteps)
154
+ mask (torch.Tensor): output_mask
155
+ shape: (batch_size, 1, mel_timesteps)
156
+ n_timesteps (int): number of diffusion steps
157
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
158
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
159
+ shape: (batch_size, spk_emb_dim)
160
+ cond: Not used but kept for future purposes
161
+
162
+ Returns:
163
+ sample: generated mel-spectrogram
164
+ shape: (batch_size, n_feats, mel_timesteps)
165
+ """
166
+ z = torch.randn_like(mu) * temperature
167
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
168
+ if self.t_scheduler == 'cosine':
169
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
170
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
171
+
172
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
173
+ """
174
+ Fixed-step Euler solver for ODEs.
175
+ Args:
176
+ x (torch.Tensor): random noise
177
+ t_span (torch.Tensor): n_timesteps interpolated
178
+ shape: (n_timesteps + 1,)
179
+ mu (torch.Tensor): output of encoder
180
+ shape: (batch_size, n_feats, mel_timesteps)
181
+ mask (torch.Tensor): output_mask
182
+ shape: (batch_size, 1, mel_timesteps)
183
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
184
+ shape: (batch_size, spk_emb_dim)
185
+ cond: Not used but kept for future purposes
186
+ """
187
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
188
+ t = t.unsqueeze(dim=0)
189
+
190
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
191
+ # Or in future might add like a return_all_steps flag
192
+ sol = []
193
+
194
+ for step in range(1, len(t_span)):
195
+ dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
196
+ # Classifier-Free Guidance inference introduced in VoiceBox
197
+ if self.inference_cfg_rate > 0:
198
+ cfg_dphi_dt = self.forward_estimator(
199
+ x, mask,
200
+ torch.zeros_like(mu), t,
201
+ torch.zeros_like(spks) if spks is not None else None,
202
+ torch.zeros_like(cond)
203
+ )
204
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
205
+ self.inference_cfg_rate * cfg_dphi_dt)
206
+ x = x + dt * dphi_dt
207
+ t = t + dt
208
+ sol.append(x)
209
+ if step < len(t_span) - 1:
210
+ dt = t_span[step + 1] - t
211
+
212
+ return sol[-1]
213
+
214
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
215
+ if isinstance(self.estimator, torch.nn.Module):
216
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
217
+ else:
218
+ ort_inputs = {
219
+ 'x': x.cpu().numpy(),
220
+ 'mask': mask.cpu().numpy(),
221
+ 'mu': mu.cpu().numpy(),
222
+ 't': t.cpu().numpy(),
223
+ 'spks': spks.cpu().numpy(),
224
+ 'cond': cond.cpu().numpy()
225
+ }
226
+ output = self.estimator.run(None, ort_inputs)[0]
227
+ return torch.tensor(output, dtype=x.dtype, device=x.device)
228
+
229
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
230
+ """Computes diffusion loss
231
+
232
+ Args:
233
+ x1 (torch.Tensor): Target
234
+ shape: (batch_size, n_feats, mel_timesteps)
235
+ mask (torch.Tensor): target mask
236
+ shape: (batch_size, 1, mel_timesteps)
237
+ mu (torch.Tensor): output of encoder
238
+ shape: (batch_size, n_feats, mel_timesteps)
239
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
240
+ shape: (batch_size, spk_emb_dim)
241
+
242
+ Returns:
243
+ loss: conditional flow matching loss
244
+ y: conditional flow
245
+ shape: (batch_size, n_feats, mel_timesteps)
246
+ """
247
+ b, _, t = mu.shape
248
+
249
+ # random timestep
250
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
251
+ if self.t_scheduler == 'cosine':
252
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
253
+ # sample noise p(x_0)
254
+ z = torch.randn_like(x1)
255
+
256
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
257
+ u = x1 - (1 - self.sigma_min) * z
258
+
259
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
260
+ if self.training_cfg_rate > 0:
261
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
262
+ mu = mu * cfg_mask.view(-1, 1, 1)
263
+ spks = spks * cfg_mask.view(-1, 1)
264
+ cond = cond * cfg_mask.view(-1, 1, 1)
265
+
266
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
267
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
268
+ return loss, y
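compute_loss regresses the estimator onto the conditional flow: along the path y(t) = (1 - (1 - sigma_min) t) z + t x1 the velocity is constant and equals u = x1 - (1 - sigma_min) z, which is exactly the target in the MSE term above. A short numeric sanity check of that identity, assuming only torch:

# Hedged sanity check, not part of the commit: the finite difference of y(t) matches u.
import torch

sigma_min = 1e-6
x1 = torch.randn(1, 80, 10)       # "clean" mel target
z = torch.randn_like(x1)          # noise sample x_0

t = torch.rand(())
y = (1 - (1 - sigma_min) * t) * z + t * x1
u = x1 - (1 - sigma_min) * z

eps = 1e-3
y_eps = (1 - (1 - sigma_min) * (t + eps)) * z + (t + eps) * x1
print(torch.allclose((y_eps - y) / eps, u, atol=1e-3))   # True: dy/dt equals u at every t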
VietTTS/flow/length_regulator.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import Tuple
2
+ import torch.nn as nn
3
+ import torch
4
+ from torch.nn import functional as F
5
+ from VietTTS.utils.mask import make_pad_mask
6
+
7
+
8
+ class InterpolateRegulator(nn.Module):
9
+ def __init__(
10
+ self,
11
+ channels: int,
12
+ sampling_ratios: Tuple,
13
+ out_channels: int = None,
14
+ groups: int = 1,
15
+ ):
16
+ super().__init__()
17
+ self.sampling_ratios = sampling_ratios
18
+ out_channels = out_channels or channels
19
+ model = nn.ModuleList([])
20
+ if len(sampling_ratios) > 0:
21
+ for _ in sampling_ratios:
22
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
23
+ norm = nn.GroupNorm(groups, channels)
24
+ act = nn.Mish()
25
+ model.extend([module, norm, act])
26
+ model.append(
27
+ nn.Conv1d(channels, out_channels, 1, 1)
28
+ )
29
+ self.model = nn.Sequential(*model)
30
+
31
+ def forward(self, x, ylens=None):
32
+ # x in (B, T, D)
33
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
34
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
35
+ out = self.model(x).transpose(1, 2).contiguous()
36
+ olens = ylens
37
+ return out * mask, olens
38
+
39
+ def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
40
+ # in inference mode, interpolate the prompt tokens and the target tokens (head/mid/tail) separately, so we get a clear separation point in the mel
41
+ # x in (B, T, D)
42
+ if x2.shape[1] > 40:
43
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
44
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
45
+ mode='linear')
46
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
47
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
48
+ else:
49
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
50
+ if x1.shape[1] != 0:
51
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
52
+ x = torch.concat([x1, x2], dim=2)
53
+ else:
54
+ x = x2
55
+ out = self.model(x).transpose(1, 2).contiguous()
56
+ return out, mel_len1 + mel_len2
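The regulator's core operation is a linear interpolation of encoder features along the time axis to the target mel length, which the small Conv1d/GroupNorm/Mish stack then refines. A minimal illustration of the interpolation step alone, assuming only torch:

# Hedged illustration, not part of the commit.
import torch
import torch.nn.functional as F

h = torch.randn(1, 100, 512)                       # (B, T_token, D) encoder output
mel_len = 172                                      # target number of mel frames
h_up = F.interpolate(h.transpose(1, 2), size=mel_len, mode="linear")
print(h_up.shape)                                  # torch.Size([1, 512, 172])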
VietTTS/frontend.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import whisper
5
+ import onnxruntime
6
+ import numpy as np
7
+ import torchaudio.compliance.kaldi as kaldi
8
+ from typing import Callable, List, Union
9
+ from functools import partial
10
+ from loguru import logger
11
+
12
+ from VietTTS.utils.frontend_utils import split_text, normalize_text, mel_spectrogram
13
+ from VietTTS.tokenizer.tokenizer import get_tokenizer
14
+
15
+ class TTSFrontEnd:
16
+ def __init__(
17
+ self,
18
+ speech_embedding_model: str,
19
+ speech_tokenizer_model: str,
20
+ ):
21
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
+ self.tokenizer = get_tokenizer()
23
+ option = onnxruntime.SessionOptions()
24
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
25
+ option.intra_op_num_threads = 1
26
+ self.speech_embedding_session = onnxruntime.InferenceSession(
27
+ speech_embedding_model,
28
+ sess_options=option,
29
+ providers=["CPUExecutionProvider"]
30
+ )
31
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(
32
+ speech_tokenizer_model,
33
+ sess_options=option,
34
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]
35
+ )
36
+ self.spk2info = {}
37
+
38
+ def _extract_text_token(self, text: str):
39
+ text_token = self.tokenizer.encode(text, allowed_special='all')
40
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
41
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
42
+ return text_token, text_token_len
43
+
44
+ def _extract_speech_token(self, speech: torch.Tensor):
45
+ if speech.shape[1] / 16000 > 30:
46
+ speech = speech[:, :int(16000 * 30)]
47
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
48
+ speech_token = self.speech_tokenizer_session.run(
49
+ None,
50
+ {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
51
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)}
52
+ )[0].flatten().tolist()
53
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
54
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
55
+ return speech_token, speech_token_len
56
+
57
+ def _extract_spk_embedding(self, speech: torch.Tensor):
58
+ feat = kaldi.fbank(
59
+ waveform=speech,
60
+ num_mel_bins=80,
61
+ dither=0,
62
+ sample_frequency=16000
63
+ )
64
+ feat = feat - feat.mean(dim=0, keepdim=True)
65
+ embedding = self.speech_embedding_session.run(
66
+ None,
67
+ {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
68
+ )[0].flatten().tolist()
69
+ embedding = torch.tensor([embedding]).to(self.device)
70
+ return embedding
71
+
72
+ def _extract_speech_feat(self, speech: torch.Tensor):
73
+ speech_feat = mel_spectrogram(
74
+ y=speech,
75
+ n_fft=1024,
76
+ num_mels=80,
77
+ sampling_rate=22050,
78
+ hop_size=256,
79
+ win_size=1024,
80
+ fmin=0,
81
+ fmax=8000,
82
+ center=False
83
+ ).squeeze(dim=0).transpose(0, 1).to(self.device)
84
+ speech_feat = speech_feat.unsqueeze(dim=0)
85
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
86
+ return speech_feat, speech_feat_len
87
+
88
+ def preprocess_text(self, text, split=True) -> Union[str, List[str]]:
89
+ text = normalize_text(text)
90
+ if split:
91
+ text = list(split_text(
92
+ text=text,
93
+ tokenize=partial(self.tokenizer.encode, allowed_special='all'),
94
+ token_max_n=30,
95
+ token_min_n=10,
96
+ merge_len=5,
97
+ comma_split=False
98
+ ))
99
+ return text
100
+
101
+ def frontend_tts(
102
+ self,
103
+ text: str,
104
+ prompt_speech_16k: Union[np.ndarray, torch.Tensor]
105
+ ) -> dict:
106
+ if isinstance(prompt_speech_16k, np.ndarray):
107
+ prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
108
+
109
+ text_token, text_token_len = self._extract_text_token(text)
110
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
111
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
112
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
113
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
114
+
115
+ model_input = {
116
+ 'text': text_token,
117
+ 'text_len': text_token_len,
118
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
119
+ 'prompt_speech_feat': speech_feat,
120
+ 'prompt_speech_feat_len': speech_feat_len,
121
+ 'llm_embedding': embedding,
122
+ 'flow_embedding': embedding
123
+ }
124
+ return model_input
125
+
126
+
127
+ def frontend_vc(
128
+ self,
129
+ source_speech_16k: Union[np.ndarray, torch.Tensor],
130
+ prompt_speech_16k: Union[np.ndarray, torch.Tensor]
131
+ ) -> dict:
132
+ if isinstance(source_speech_16k, np.ndarray):
133
+ source_speech_16k = torch.from_numpy(source_speech_16k)
134
+ if isinstance(prompt_speech_16k, np.ndarray):
135
+ prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
136
+
137
+ prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
138
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
139
+ prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
140
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
141
+ source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
142
+ model_input = {
143
+ 'source_speech_token': source_speech_token,
144
+ 'source_speech_token_len': source_speech_token_len,
145
+ 'flow_prompt_speech_token': prompt_speech_token,
146
+ 'flow_prompt_speech_token_len': prompt_speech_token_len,
147
+ 'prompt_speech_feat': prompt_speech_feat,
148
+ 'prompt_speech_feat_len': prompt_speech_feat_len,
149
+ 'flow_embedding': embedding
150
+ }
151
+ return model_input
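TTSFrontEnd turns a text plus a 16 kHz prompt waveform into the model_input dict consumed downstream (text tokens, prompt speech tokens, 22.05 kHz mel features and a speaker embedding). A hedged usage sketch: the model and sample paths come from this commit's file list, but loading the prompt with torchaudio instead of the project's own helpers in VietTTS.utils.file_utils is an assumption, as is the prompt being a mono recording:

# Hedged usage sketch, not part of the commit.
import torchaudio
from VietTTS.frontend import TTSFrontEnd

frontend = TTSFrontEnd(
    speech_embedding_model="VietTTS/models/speech_embedding.onnx",
    speech_tokenizer_model="VietTTS/models/speech_tokenizer.onnx",
)

prompt, sr = torchaudio.load("VietTTS/samples/quynh.wav")          # assumed mono
prompt_16k = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(prompt)

segments = frontend.preprocess_text("Xin chào, đây là VietTTS.", split=True)
model_input = frontend.frontend_tts(segments[0], prompt_16k)
print(sorted(model_input.keys()))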
VietTTS/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils import weight_norm
4
+
5
+
6
+ class ConvRNNF0Predictor(nn.Module):
7
+ def __init__(self,
8
+ num_class: int = 1,
9
+ in_channels: int = 80,
10
+ cond_channels: int = 512
11
+ ):
12
+ super().__init__()
13
+
14
+ self.num_class = num_class
15
+ self.condnet = nn.Sequential(
16
+ weight_norm(
17
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
18
+ ),
19
+ nn.ELU(),
20
+ weight_norm(
21
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
22
+ ),
23
+ nn.ELU(),
24
+ weight_norm(
25
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
26
+ ),
27
+ nn.ELU(),
28
+ weight_norm(
29
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
30
+ ),
31
+ nn.ELU(),
32
+ weight_norm(
33
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
34
+ ),
35
+ nn.ELU(),
36
+ )
37
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ x = self.condnet(x)
41
+ x = x.transpose(1, 2)
42
+ return torch.abs(self.classifier(x).squeeze(-1))
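ConvRNNF0Predictor maps an 80-bin mel spectrogram to one non-negative F0 value per frame (the final abs keeps the prediction in a valid pitch range). A quick shape check, assuming only torch:

# Hedged shape check, not part of the commit.
import torch
from VietTTS.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor(num_class=1, in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 120)     # (batch, mel_bins, frames)
f0 = predictor(mel)
print(f0.shape)                   # torch.Size([2, 120]): one F0 value per frame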
VietTTS/hifigan/generator.py ADDED
@@ -0,0 +1,384 @@
1
+ """HIFI-GAN"""
2
+
3
+ import typing as tp
4
+ import numpy as np
5
+ from scipy.signal import get_window
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import Conv1d
10
+ from torch.nn import ConvTranspose1d
11
+ from torch.nn.utils import remove_weight_norm
12
+ from torch.nn.utils import weight_norm
13
+ from torch.distributions.uniform import Uniform
14
+
15
+ from VietTTS.transformer.activation import Snake
16
+ from VietTTS.utils.common import get_padding
17
+ from VietTTS.utils.common import init_weights
18
+
19
+
20
+ """hifigan based generator implementation.
21
+
22
+ This code is modified from https://github.com/jik876/hifi-gan
23
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
24
+ https://github.com/NVIDIA/BigVGAN
25
+
26
+ """
27
+
28
+
29
+ class ResBlock(torch.nn.Module):
30
+ """Residual block module in HiFiGAN/BigVGAN."""
31
+ def __init__(
32
+ self,
33
+ channels: int = 512,
34
+ kernel_size: int = 3,
35
+ dilations: tp.List[int] = [1, 3, 5],
36
+ ):
37
+ super(ResBlock, self).__init__()
38
+ self.convs1 = nn.ModuleList()
39
+ self.convs2 = nn.ModuleList()
40
+
41
+ for dilation in dilations:
42
+ self.convs1.append(
43
+ weight_norm(
44
+ Conv1d(
45
+ channels,
46
+ channels,
47
+ kernel_size,
48
+ 1,
49
+ dilation=dilation,
50
+ padding=get_padding(kernel_size, dilation)
51
+ )
52
+ )
53
+ )
54
+ self.convs2.append(
55
+ weight_norm(
56
+ Conv1d(
57
+ channels,
58
+ channels,
59
+ kernel_size,
60
+ 1,
61
+ dilation=1,
62
+ padding=get_padding(kernel_size, 1)
63
+ )
64
+ )
65
+ )
66
+ self.convs1.apply(init_weights)
67
+ self.convs2.apply(init_weights)
68
+ self.activations1 = nn.ModuleList([
69
+ Snake(channels, alpha_logscale=False)
70
+ for _ in range(len(self.convs1))
71
+ ])
72
+ self.activations2 = nn.ModuleList([
73
+ Snake(channels, alpha_logscale=False)
74
+ for _ in range(len(self.convs2))
75
+ ])
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ for idx in range(len(self.convs1)):
79
+ xt = self.activations1[idx](x)
80
+ xt = self.convs1[idx](xt)
81
+ xt = self.activations2[idx](xt)
82
+ xt = self.convs2[idx](xt)
83
+ x = xt + x
84
+ return x
85
+
86
+ def remove_weight_norm(self):
87
+ for idx in range(len(self.convs1)):
88
+ remove_weight_norm(self.convs1[idx])
89
+ remove_weight_norm(self.convs2[idx])
90
+
91
+
92
+ class SineGen(torch.nn.Module):
93
+ """ Definition of sine generator
94
+ SineGen(samp_rate, harmonic_num = 0,
95
+ sine_amp = 0.1, noise_std = 0.003,
96
+ voiced_threshold = 0,
97
+ flag_for_pulse=False)
98
+ samp_rate: sampling rate in Hz
99
+ harmonic_num: number of harmonic overtones (default 0)
100
+ sine_amp: amplitude of sine-waveform (default 0.1)
101
+ noise_std: std of Gaussian noise (default 0.003)
102
+ voiced_threshold: F0 threshold for U/V classification (default 0)
103
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
104
+ Note: when flag_for_pulse is True, the first time step of a voiced
105
+ segment is always sin(np.pi) or cos(0)
106
+ """
107
+
108
+ def __init__(self, samp_rate, harmonic_num=0,
109
+ sine_amp=0.1, noise_std=0.003,
110
+ voiced_threshold=0):
111
+ super(SineGen, self).__init__()
112
+ self.sine_amp = sine_amp
113
+ self.noise_std = noise_std
114
+ self.harmonic_num = harmonic_num
115
+ self.sampling_rate = samp_rate
116
+ self.voiced_threshold = voiced_threshold
117
+
118
+ def _f02uv(self, f0):
119
+ # generate uv signal
120
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
121
+ return uv
122
+
123
+ @torch.no_grad()
124
+ def forward(self, f0):
125
+ """
126
+ :param f0: [B, 1, sample_len], Hz
127
+ :return: [B, 1, sample_len]
128
+ """
129
+
130
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
131
+ for i in range(self.harmonic_num + 1):
132
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
133
+
134
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
135
+ u_dist = Uniform(low=-np.pi, high=np.pi)
136
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
137
+ phase_vec[:, 0, :] = 0
138
+
139
+ # generate sine waveforms
140
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
141
+
142
+ # generate uv signal
143
+ uv = self._f02uv(f0)
144
+
145
+ # noise: for unvoiced should be similar to sine_amp
146
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
147
+ # . for voiced regions is self.noise_std
148
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
149
+ noise = noise_amp * torch.randn_like(sine_waves)
150
+
151
+ # first: set the unvoiced part to 0 by uv
152
+ # then: additive noise
153
+ sine_waves = sine_waves * uv + noise
154
+ return sine_waves, uv, noise
155
+
156
+
157
+ class SourceModuleHnNSF(torch.nn.Module):
158
+ """ SourceModule for hn-nsf
159
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
160
+ add_noise_std=0.003, voiced_threshod=0)
161
+ sampling_rate: sampling_rate in Hz
162
+ harmonic_num: number of harmonic above F0 (default: 0)
163
+ sine_amp: amplitude of sine source signal (default: 0.1)
164
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
165
+ note that amplitude of noise in unvoiced is decided
166
+ by sine_amp
167
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
168
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
169
+ F0_sampled (batchsize, length, 1)
170
+ Sine_source (batchsize, length, 1)
171
+ noise_source (batchsize, length 1)
172
+ uv (batchsize, length, 1)
173
+ """
174
+
175
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
176
+ add_noise_std=0.003, voiced_threshod=0):
177
+ super(SourceModuleHnNSF, self).__init__()
178
+
179
+ self.sine_amp = sine_amp
180
+ self.noise_std = add_noise_std
181
+
182
+ # to produce sine waveforms
183
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
184
+ sine_amp, add_noise_std, voiced_threshod)
185
+
186
+ # to merge source harmonics into a single excitation
187
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
188
+ self.l_tanh = torch.nn.Tanh()
189
+
190
+ def forward(self, x):
191
+ """
192
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
193
+ F0_sampled (batchsize, length, 1)
194
+ Sine_source (batchsize, length, 1)
195
+ noise_source (batchsize, length 1)
196
+ """
197
+ # source for harmonic branch
198
+ with torch.no_grad():
199
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
200
+ sine_wavs = sine_wavs.transpose(1, 2)
201
+ uv = uv.transpose(1, 2)
202
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
203
+
204
+ # source for noise branch, in the same shape as uv
205
+ noise = torch.randn_like(uv) * self.sine_amp / 3
206
+ return sine_merge, noise, uv
207
+
208
+
209
+ class HiFTGenerator(nn.Module):
210
+ """
211
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
212
+ https://arxiv.org/abs/2309.09493
213
+ """
214
+ def __init__(
215
+ self,
216
+ in_channels: int = 80,
217
+ base_channels: int = 512,
218
+ nb_harmonics: int = 8,
219
+ sampling_rate: int = 22050,
220
+ nsf_alpha: float = 0.1,
221
+ nsf_sigma: float = 0.003,
222
+ nsf_voiced_threshold: float = 10,
223
+ upsample_rates: tp.List[int] = [8, 8],
224
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
225
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
226
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
227
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
228
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
229
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
230
+ lrelu_slope: float = 0.1,
231
+ audio_limit: float = 0.99,
232
+ f0_predictor: torch.nn.Module = None,
233
+ ):
234
+ super(HiFTGenerator, self).__init__()
235
+
236
+ self.out_channels = 1
237
+ self.nb_harmonics = nb_harmonics
238
+ self.sampling_rate = sampling_rate
239
+ self.istft_params = istft_params
240
+ self.lrelu_slope = lrelu_slope
241
+ self.audio_limit = audio_limit
242
+
243
+ self.num_kernels = len(resblock_kernel_sizes)
244
+ self.num_upsamples = len(upsample_rates)
245
+ self.m_source = SourceModuleHnNSF(
246
+ sampling_rate=sampling_rate,
247
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
248
+ harmonic_num=nb_harmonics,
249
+ sine_amp=nsf_alpha,
250
+ add_noise_std=nsf_sigma,
251
+ voiced_threshod=nsf_voiced_threshold)
252
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
253
+
254
+ self.conv_pre = weight_norm(
255
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
256
+ )
257
+
258
+ # Up
259
+ self.ups = nn.ModuleList()
260
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
261
+ self.ups.append(
262
+ weight_norm(
263
+ ConvTranspose1d(
264
+ base_channels // (2**i),
265
+ base_channels // (2**(i + 1)),
266
+ k,
267
+ u,
268
+ padding=(k - u) // 2,
269
+ )
270
+ )
271
+ )
272
+
273
+ # Down
274
+ self.source_downs = nn.ModuleList()
275
+ self.source_resblocks = nn.ModuleList()
276
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
277
+ downsample_cum_rates = np.cumprod(downsample_rates)
278
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
279
+ if u == 1:
280
+ self.source_downs.append(
281
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
282
+ )
283
+ else:
284
+ self.source_downs.append(
285
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
286
+ )
287
+
288
+ self.source_resblocks.append(
289
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
290
+ )
291
+
292
+ self.resblocks = nn.ModuleList()
293
+ for i in range(len(self.ups)):
294
+ ch = base_channels // (2**(i + 1))
295
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
296
+ self.resblocks.append(ResBlock(ch, k, d))
297
+
298
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
299
+ self.ups.apply(init_weights)
300
+ self.conv_post.apply(init_weights)
301
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
302
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
303
+ self.f0_predictor = f0_predictor
304
+
305
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
306
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
307
+
308
+ har_source, _, _ = self.m_source(f0)
309
+ return har_source.transpose(1, 2)
310
+
311
+ def _stft(self, x):
312
+ spec = torch.stft(
313
+ x,
314
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
315
+ return_complex=True)
316
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
317
+ return spec[..., 0], spec[..., 1]
318
+
319
+ def _istft(self, magnitude, phase):
320
+ magnitude = torch.clip(magnitude, max=1e2)
321
+ real = magnitude * torch.cos(phase)
322
+ img = magnitude * torch.sin(phase)
323
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
324
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
325
+ return inverse_transform
326
+
327
+ def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
328
+ f0 = self.f0_predictor(x)
329
+ s = self._f02source(f0)
330
+
331
+ # use cache_source to avoid glitch
332
+ if cache_source.shape[2] != 0:
333
+ s[:, :, :cache_source.shape[2]] = cache_source
334
+
335
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
336
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
337
+
338
+ x = self.conv_pre(x)
339
+ for i in range(self.num_upsamples):
340
+ x = F.leaky_relu(x, self.lrelu_slope)
341
+ x = self.ups[i](x)
342
+
343
+ if i == self.num_upsamples - 1:
344
+ x = self.reflection_pad(x)
345
+
346
+ # fusion
347
+ si = self.source_downs[i](s_stft)
348
+ si = self.source_resblocks[i](si)
349
+ x = x + si
350
+
351
+ xs = None
352
+ for j in range(self.num_kernels):
353
+ if xs is None:
354
+ xs = self.resblocks[i * self.num_kernels + j](x)
355
+ else:
356
+ xs += self.resblocks[i * self.num_kernels + j](x)
357
+ x = xs / self.num_kernels
358
+
359
+ x = F.leaky_relu(x)
360
+ x = self.conv_post(x)
361
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
362
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # note: the sin here is technically redundant
363
+
364
+ x = self._istft(magnitude, phase)
365
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
366
+ return x, s
367
+
368
+ def remove_weight_norm(self):
369
+ print('Removing weight norm...')
370
+ for l in self.ups:
371
+ remove_weight_norm(l)
372
+ for l in self.resblocks:
373
+ l.remove_weight_norm()
374
+ remove_weight_norm(self.conv_pre)
375
+ remove_weight_norm(self.conv_post)
376
+ # note: self.m_source contains no weight-normalized layers, so there is nothing to remove for it
377
+ for l in self.source_downs:
378
+ remove_weight_norm(l)
379
+ for l in self.source_resblocks:
380
+ l.remove_weight_norm()
381
+
382
+ @torch.inference_mode()
383
+ def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
384
+ return self.forward(x=mel, cache_source=cache_source)
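With the default hyperparameters (two 8x upsampling stages and a 16-point iSTFT with hop 4), the generator produces roughly 256 output samples per mel frame at 22.05 kHz. A smoke-test sketch with random, untrained weights, assuming the module paths shown above:

```python
# Shape-level smoke test for HiFTGenerator; the audio is noise because the
# weights are untrained, the point is the mel-frame -> sample mapping.
import torch
from VietTTS.hifigan.f0_predictor import ConvRNNF0Predictor
from VietTTS.hifigan.generator import HiFTGenerator

hift = HiFTGenerator(f0_predictor=ConvRNNF0Predictor())
mel = torch.randn(1, 80, 50)                 # [B, n_mels, frames]
wav, source = hift.inference(mel=mel)
print(wav.shape)                             # about 50 * 256 = 12800 samples
print(source.shape)                          # NSF excitation of the same length
```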
VietTTS/llm/llm.py ADDED
@@ -0,0 +1,199 @@
1
+ from typing import Dict, Optional, Callable, List, Generator
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
6
+ from VietTTS.utils.common import IGNORE_ID
7
+ from VietTTS.transformer.label_smoothing_loss import LabelSmoothingLoss
8
+ from VietTTS.utils.common import th_accuracy
9
+
10
+
11
+ class TransformerLM(torch.nn.Module):
12
+ def __init__(
13
+ self,
14
+ text_encoder_input_size: int,
15
+ llm_input_size: int,
16
+ llm_output_size: int,
17
+ text_token_size: int,
18
+ speech_token_size: int,
19
+ text_encoder: torch.nn.Module,
20
+ llm: torch.nn.Module,
21
+ sampling: Callable,
22
+ length_normalized_loss: bool = True,
23
+ lsm_weight: float = 0.0,
24
+ spk_embed_dim: int = 192,
25
+ ):
26
+ super().__init__()
27
+ self.llm_input_size = llm_input_size
28
+ self.speech_token_size = speech_token_size
29
+ # 1. build text token inputs related modules
30
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
31
+ self.text_encoder = text_encoder
32
+ self.text_encoder_affine_layer = nn.Linear(
33
+ self.text_encoder.output_size(),
34
+ llm_input_size
35
+ )
36
+
37
+ # 2. build speech token language model related modules
38
+ self.sos_eos = 0
39
+ self.task_id = 1
40
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
41
+ self.llm = llm
42
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
43
+ self.criterion_ce = LabelSmoothingLoss(
44
+ size=speech_token_size + 1,
45
+ padding_idx=IGNORE_ID,
46
+ smoothing=lsm_weight,
47
+ normalize_length=length_normalized_loss,
48
+ )
49
+
50
+ # 3. [Optional] build speech token related modules
51
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
52
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
53
+
54
+ # 4. sampling method
55
+ self.sampling = sampling
56
+
57
+ def encode(
58
+ self,
59
+ text: torch.Tensor,
60
+ text_lengths: torch.Tensor,
61
+ ):
62
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
63
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
64
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
65
+ return encoder_out, encoder_out_lens
66
+
67
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
68
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
69
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
70
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
71
+ for i in range(len(text_token))]
72
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
73
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
74
+ return lm_input, lm_input_len
75
+
76
+ def forward(
77
+ self,
78
+ batch: dict,
79
+ device: torch.device,
80
+ ) -> Dict[str, Optional[torch.Tensor]]:
81
+ """
82
+ Args:
83
+ text: (B, L, D)
84
+ text_lengths: (B,)
85
+ audio: (B, T, N) or (B, T)
86
+ audio_lengths: (B,)
87
+ """
88
+ text_token = batch['text_token'].to(device)
89
+ text_token_len = batch['text_token_len'].to(device)
90
+ speech_token = batch['speech_token'].to(device)
91
+ speech_token_len = batch['speech_token_len'].to(device)
92
+ embedding = batch['embedding'].to(device)
93
+
94
+ # 1. prepare llm_target
95
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
96
+ [self.speech_token_size]) for i in range(text_token.size(0))]
97
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
98
+
99
+ # 2. encode text_token
100
+ text_token = self.text_embedding(text_token)
101
+ text_token, text_token_len = self.encode(text_token, text_token_len)
102
+
103
+ # 3. embedding projection
104
+ embedding = F.normalize(embedding, dim=1)
105
+ embedding = self.spk_embed_affine_layer(embedding)
106
+ embedding = embedding.unsqueeze(1)
107
+
108
+ # 4. eos and task_id
109
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
110
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
111
+
112
+ # 5. encode speech_token
113
+ speech_token = self.speech_embedding(speech_token)
114
+
115
+ # 6. unpad and pad
116
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
117
+ task_id_emb, speech_token, speech_token_len)
118
+
119
+ # 7. run lm forward
120
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
121
+ logits = self.llm_decoder(lm_output)
122
+ loss = self.criterion_ce(logits, lm_target)
123
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
124
+ return {'loss': loss, 'acc': acc}
125
+
126
+ def sampling_ids(
127
+ self,
128
+ weighted_scores: torch.Tensor,
129
+ decoded_tokens: List,
130
+ sampling: int,
131
+ ignore_eos: bool = True,
132
+ ):
133
+ while True:
134
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
135
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
136
+ break
137
+ return top_ids
138
+
139
+ @torch.inference_mode()
140
+ def inference(
141
+ self,
142
+ text: torch.Tensor,
143
+ text_len: torch.Tensor,
144
+ prompt_text: torch.Tensor,
145
+ prompt_text_len: torch.Tensor,
146
+ prompt_speech_token: torch.Tensor,
147
+ prompt_speech_token_len: torch.Tensor,
148
+ embedding: torch.Tensor,
149
+ sampling: int = 25,
150
+ max_token_text_ratio: float = 20,
151
+ min_token_text_ratio: float = 2,
152
+ ) -> Generator[torch.Tensor, None, None]:
153
+ device = text.device
154
+ text = torch.concat([prompt_text, text], dim=1)
155
+ text_len += prompt_text_len
156
+ text = self.text_embedding(text)
157
+
158
+ # 1. encode text
159
+ text, text_len = self.encode(text, text_len)
160
+
161
+ # 2. encode embedding
162
+ if embedding.shape[0] != 0:
163
+ embedding = F.normalize(embedding, dim=1)
164
+ embedding = self.spk_embed_affine_layer(embedding)
165
+ embedding = embedding.unsqueeze(dim=1)
166
+ else:
167
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
168
+
169
+ # 3. concat llm_input
170
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
171
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
172
+ if prompt_speech_token_len != 0:
173
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
174
+ else:
175
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
176
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
177
+
178
+ # 4. cal min/max_length
179
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
180
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
181
+
182
+ # 5. step by step decode
183
+ out_tokens = []
184
+ offset = 0
185
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
186
+ for i in range(max_len):
187
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
188
+ att_cache=att_cache, cnn_cache=cnn_cache,
189
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
190
+ device=lm_input.device)).to(torch.bool))
191
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
192
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
193
+ if top_ids == self.speech_token_size:
194
+ break
195
+ # in stream mode, yield token one by one
196
+ yield top_ids
197
+ out_tokens.append(top_ids)
198
+ offset += lm_input.size(1)
199
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
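`TransformerLM.inference` streams speech-token ids one at a time; decoding stops at the EOS id (`speech_token_size`) and is bounded by `min_token_text_ratio`/`max_token_text_ratio` (with the defaults, a 10-token text yields between 20 and 200 speech tokens). A consumption sketch, assuming `llm` is an already-loaded `TransformerLM`, `text_ids` is a `[1, T]` int32 tensor of text-token ids, and `spk_emb` is a `[1, 192]` speaker embedding:

```python
# Streaming consumption of the speech-token generator; `llm`, `text_ids`, and
# `spk_emb` are assumed to be prepared elsewhere (e.g. by the frontend).
import torch

speech_tokens = []
for token_id in llm.inference(
    text=text_ids,
    text_len=torch.tensor([text_ids.shape[1]], dtype=torch.int32),
    prompt_text=torch.zeros(1, 0, dtype=torch.int32),
    prompt_text_len=torch.zeros(1, dtype=torch.int32),
    prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
    prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
    embedding=spk_emb,
    sampling=25,
):
    speech_tokens.append(token_id)   # ids arrive one by one as they are decoded
```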
VietTTS/model.py ADDED
@@ -0,0 +1,260 @@
1
+ from loguru import logger
2
+ import torch
3
+ import numpy as np
4
+ import threading
5
+ import time
6
+ from torch.nn import functional as F
7
+ from contextlib import nullcontext
8
+ import uuid
9
+ from VietTTS.utils.common import fade_in_out_audio
10
+
11
+ class TTSModel:
12
+ def __init__(
13
+ self,
14
+ llm: torch.nn.Module,
15
+ flow: torch.nn.Module,
16
+ hift: torch.nn.Module
17
+ ):
18
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ self.llm = llm
20
+ self.flow = flow
21
+ self.hift = hift
22
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
23
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
24
+ self.token_overlap_len = 20
25
+ # mel fade in out
26
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
27
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
28
+ # hift cache
29
+ self.mel_cache_len = 20
30
+ self.source_cache_len = int(self.mel_cache_len * 256)
31
+ # speech fade in out
32
+ self.speech_window = np.hamming(2 * self.source_cache_len)
33
+ # rtf and decoding related
34
+ self.stream_scale_factor = 1
35
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
36
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
37
+ self.lock = threading.Lock()
38
+ # dict used to store session related variable
39
+ self.tts_speech_token_dict = {}
40
+ self.llm_end_dict = {}
41
+ self.mel_overlap_dict = {}
42
+ self.hift_cache_dict = {}
43
+
44
+ def load(self, llm_model, flow_model, hift_model):
45
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
46
+ self.llm.to(self.device).eval()
47
+ self.llm.half()
48
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
49
+ self.flow.to(self.device).eval()
50
+ self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
51
+ self.hift.to(self.device).eval()
52
+
53
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
54
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
55
+ self.llm.text_encoder = llm_text_encoder
56
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
57
+ self.llm.llm = llm_llm
58
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
59
+ self.flow.encoder = flow_encoder
60
+
61
+ def load_onnx(self, flow_decoder_estimator_model):
62
+ import onnxruntime
63
+ option = onnxruntime.SessionOptions()
64
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
65
+ option.intra_op_num_threads = 1
66
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
67
+ del self.flow.decoder.estimator
68
+ self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
69
+
70
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
71
+ with self.llm_context:
72
+ for i in self.llm.inference(
73
+ text=text.to(self.device),
74
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
75
+ prompt_text=prompt_text.to(self.device),
76
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
77
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
78
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
79
+ embedding=llm_embedding.to(self.device).half()
80
+ ):
81
+ self.tts_speech_token_dict[uuid].append(i)
82
+ self.llm_end_dict[uuid] = True
83
+
84
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
85
+ tts_mel = self.flow.inference(
86
+ token=token.to(self.device),
87
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
88
+ prompt_token=prompt_token.to(self.device),
89
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
90
+ prompt_feat=prompt_feat.to(self.device),
91
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
92
+ embedding=embedding.to(self.device)
93
+ )
94
+
95
+ if self.hift_cache_dict[uuid] is not None:
96
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
97
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
98
+ else:
99
+ hift_cache_source = torch.zeros(1, 1, 0)
100
+
101
+ if finalize is False:
102
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
103
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
104
+ tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
105
+ self.hift_cache_dict[uuid] = {
106
+ 'mel': tts_mel[:, :, -self.mel_cache_len:],
107
+ 'source': tts_source[:, :, -self.source_cache_len:],
108
+ 'speech': tts_speech[:, -self.source_cache_len:]
109
+ }
110
+ tts_speech = tts_speech[:, :-self.source_cache_len]
111
+ else:
112
+ if speed != 1.0:
113
+ assert self.hift_cache_dict[uuid] is None, 'speed change is only supported in non-stream inference mode'
114
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
115
+ tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
116
+
117
+ tts_speech = fade_in_out_audio(tts_speech)
118
+ return tts_speech
119
+
120
+ def tts(
121
+ self,
122
+ text: str,
123
+ flow_embedding: torch.Tensor,
124
+ llm_embedding: torch.Tensor=torch.zeros(0, 192),
125
+ prompt_text: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32),
126
+ llm_prompt_speech_token: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32),
127
+ flow_prompt_speech_token: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32),
128
+ prompt_speech_feat: torch.Tensor=torch.zeros(1, 0, 80),
129
+ stream: bool=False,
130
+ speed: float=1.0,
131
+ **kwargs
132
+ ):
133
+ # this_uuid is used to track variables related to this inference thread
134
+ this_uuid = str(uuid.uuid1())
135
+ with self.lock:
136
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
137
+ self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
138
+
139
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
140
+ p.start()
141
+
142
+ if stream:
143
+ token_hop_len = self.token_min_hop_len
144
+ while True:
145
+ time.sleep(0.01)
146
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
147
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]).unsqueeze(dim=0)
148
+ this_tts_speech = self.token2wav(
149
+ token=this_tts_speech_token,
150
+ prompt_token=flow_prompt_speech_token,
151
+ prompt_feat=prompt_speech_feat,
152
+ embedding=flow_embedding,
153
+ uuid=this_uuid,
154
+ finalize=False
155
+ )
156
+ yield {'tts_speech': this_tts_speech.cpu()}
157
+ with self.lock:
158
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
159
+ # increase token_hop_len for better speech quality
160
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
161
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
162
+ break
163
+ p.join()
164
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
165
+ this_tts_speech = self.token2wav(
166
+ token=this_tts_speech_token,
167
+ prompt_token=flow_prompt_speech_token,
168
+ prompt_feat=prompt_speech_feat,
169
+ embedding=flow_embedding,
170
+ uuid=this_uuid,
171
+ finalize=True
172
+ )
173
+ yield {'tts_speech': this_tts_speech.cpu()}
174
+ else:
175
+ p.join()
176
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
177
+ this_tts_speech = self.token2wav(
178
+ token=this_tts_speech_token,
179
+ prompt_token=flow_prompt_speech_token,
180
+ prompt_feat=prompt_speech_feat,
181
+ embedding=flow_embedding,
182
+ uuid=this_uuid,
183
+ finalize=True,
184
+ speed=speed
185
+ )
186
+ yield {'tts_speech': this_tts_speech.cpu()}
187
+
188
+ with self.lock:
189
+ self.tts_speech_token_dict.pop(this_uuid)
190
+ self.llm_end_dict.pop(this_uuid)
191
+ self.mel_overlap_dict.pop(this_uuid)
192
+ self.hift_cache_dict.pop(this_uuid)
193
+
194
+ def vc(
195
+ self,
196
+ source_speech_token: torch.Tensor,
197
+ flow_prompt_speech_token: torch.Tensor,
198
+ prompt_speech_feat: torch.Tensor,
199
+ flow_embedding: torch.Tensor,
200
+ stream: bool=False,
201
+ speed: float=1.0,
202
+ **kwargs
203
+ ):
204
+ this_uuid = str(uuid.uuid1())
205
+ with self.lock:
206
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
207
+ self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
208
+
209
+ if stream:
210
+ token_hop_len = self.token_min_hop_len
211
+ while True:
212
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
213
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
214
+ .unsqueeze(dim=0)
215
+ this_tts_speech = self.token2wav(
216
+ token=this_tts_speech_token,
217
+ prompt_token=flow_prompt_speech_token,
218
+ prompt_feat=prompt_speech_feat,
219
+ embedding=flow_embedding,
220
+ uuid=this_uuid,
221
+ finalize=False
222
+ )
223
+ yield {'tts_speech': this_tts_speech.cpu()}
224
+ with self.lock:
225
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
226
+ # increase token_hop_len for better speech quality
227
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
228
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
229
+ break
230
+
231
+ # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when the cached speech is not None
232
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
233
+ this_tts_speech = self.token2wav(
234
+ token=this_tts_speech_token,
235
+ prompt_token=flow_prompt_speech_token,
236
+ prompt_feat=prompt_speech_feat,
237
+ embedding=flow_embedding,
238
+ uuid=this_uuid,
239
+ finalize=True
240
+ )
241
+ yield {'tts_speech': this_tts_speech.cpu()}
242
+ else:
243
+ # deal with all tokens
244
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
245
+ this_tts_speech = self.token2wav(
246
+ token=this_tts_speech_token,
247
+ prompt_token=flow_prompt_speech_token,
248
+ prompt_feat=prompt_speech_feat,
249
+ embedding=flow_embedding,
250
+ uuid=this_uuid,
251
+ finalize=True,
252
+ speed=speed
253
+ )
254
+ yield {'tts_speech': this_tts_speech.cpu()}
255
+
256
+ with self.lock:
257
+ self.tts_speech_token_dict.pop(this_uuid)
258
+ self.llm_end_dict.pop(this_uuid)
259
+ self.mel_overlap_dict.pop(this_uuid)
260
+ self.hift_cache_dict.pop(this_uuid)
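`tts` and `vc` are generators that yield `{'tts_speech': ...}` chunks, so callers decide whether to play chunks as they arrive (`stream=True`) or collect everything first. A consumption sketch, assuming `model` is a loaded `TTSModel` and `model_input` comes from the frontend preparation shown earlier (output sample rate 22050 Hz per the config below):

```python
# Collect streamed chunks into a single waveform; `model` and `model_input`
# are assumed to be set up elsewhere.
import torch
import torchaudio

chunks = [out["tts_speech"] for out in model.tts(stream=True, **model_input)]
torchaudio.save("output.wav", torch.concat(chunks, dim=1), 22050)
```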
VietTTS/models/.cache/huggingface/.gitignore ADDED
@@ -0,0 +1 @@
1
+ *
VietTTS/models/.cache/huggingface/download/.gitattributes.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/.gitattributes.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ a6344aac8c09253b3b630fb776ae94478aa0275b
3
+ 1752523446.3440115
VietTTS/models/.cache/huggingface/download/README.md.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/README.md.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ 41d572572806e892430392203f5e635d461b028a
3
+ 1752523446.4585946
VietTTS/models/.cache/huggingface/download/README_VN.md.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/README_VN.md.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ 68fa27e83e1f31dd41485b567a7475d5abab3732
3
+ 1752523446.479747
VietTTS/models/.cache/huggingface/download/config.yaml.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/config.yaml.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ 062687bc82eb8d2c4ccae395158f2066f4634390
3
+ 1752523445.83575
VietTTS/models/.cache/huggingface/download/flow.pt.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/flow.pt.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ 1411de192039a21d53f0bf1968feb50586ce71d81ea1443f8163f4d1c46c5455
3
+ 1752523560.376818
VietTTS/models/.cache/huggingface/download/hift.pt.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/hift.pt.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ 91e679b6ca1eff71187ffb4f3ab0444935594cdcc20a9bd12afad111ef8d6012
3
+ 1752523474.8549578
VietTTS/models/.cache/huggingface/download/llm.pt.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/llm.pt.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ c1773e5afe16a88ee82e33cf510a07717ce1346d2e74856733d72dc297a9a017
3
+ 1752523690.911262
VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/speech_embedding.onnx.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
3
+ 1752523473.7760808
VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.lock ADDED
File without changes
VietTTS/models/.cache/huggingface/download/speech_tokenizer.onnx.metadata ADDED
@@ -0,0 +1,3 @@
1
+ b9f49bb2ef682e162969a6919b0ed2a51a758729
2
+ 56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486
3
+ 1752523583.6788416
VietTTS/models/.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
VietTTS/models/README.md ADDED
@@ -0,0 +1,213 @@
1
+ ---
2
+ language:
3
+ - vi
4
+ - en
5
+ pipeline_tag: text-to-speech
6
+ license: apache-2.0
7
+ tags:
8
+ - tts
9
+ - text-to-speech
10
+ - vietnamese
11
+ - speech-synthesis
12
+ - speech
13
+ - viet-tts
14
+ - viettts
15
+ ---
16
+ <!-- # VietTTS: An Open-Source Vietnamese Text to Speech -->
17
+ <p align="center">
18
+ <img src="https://github.com/dangvansam/viet-tts/blob/main/assets/viet-tts-medium.png?raw=true" style="width: 200px">
19
+ <h1 align="center"style="color: white; font-weight: bold; font-family:roboto"><span style="color: white; font-weight: bold; font-family:roboto">VietTTS</span>: An Open-Source Vietnamese Text to Speech</h1>
20
+ </p>
21
+ <p align="center">
22
+ <a href="https://github.com/dangvansam/viet-tts"><img src="https://img.shields.io/github/stars/dangvansam/viet-tts?style=social"></a>
23
+ <a href="LICENSE"><img src="https://img.shields.io/github/license/dangvansam/viet-asr"></a>
24
+ <a href="https://huggingface.co/dangvansam/viet-tts/blob/main/README_VN.md"><img src="https://img.shields.io/badge/README-Tiếng Việt-blue"></a>
25
+ </p>
26
+
27
+ **VietTTS** is an open-source toolkit providing the community with a powerful Vietnamese TTS model, capable of natural voice synthesis and robust voice cloning. Designed for effective experimentation, **VietTTS** supports research and application in Vietnamese voice technologies.
28
+
29
+ ## ⭐ Key Features
30
+ - **TTS**: Text-to-Speech generation with any voice via prompt audio
31
+ - **OpenAI-API-compatible**: Compatible with OpenAI's Text-to-Speech API format
32
+
33
+ ## 🛠️ Installation
34
+
35
+ VietTTS can be installed via a Python installer (Linux only, with Windows and macOS support coming soon) or Docker.
36
+
37
+ ### Python Installer (Python>=3.10)
38
+ ```bash
39
+ git clone https://github.com/dangvansam/viet-tts.git
40
+ cd viet-tts
41
+
42
+ # (Optional) Create a Python environment with conda; you could also use virtualenv
43
+ conda create --name viettts python=3.10
44
+ conda activate viettts
45
+
46
+ # Install
47
+ pip install -e . && pip cache purge
48
+ ```
49
+
50
+ ### Docker
51
+
52
+ 1. Install [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads).
53
+
54
+ 2. Run the following commands:
55
+ ```bash
56
+ git clone https://github.com/dangvansam/viet-tts.git
57
+ cd viet-tts
58
+
59
+ # Build docker images
60
+ docker compose build
61
+
62
+ # Run with docker-compose - will create server at: http://localhost:8298
63
+ docker compose up -d
64
+
65
+ # Or run with docker run - will create server at: http://localhost:8298
66
+ docker run -itd --gpus all -p 8298:8298 -v ./pretrained-models:/app/pretrained-models --name viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298
67
+ ```
68
+
69
+ ## 🚀 Usage
70
+
71
+ ### Built-in Voices 🤠
72
+ You can use the available voices below to synthesize speech.
73
+ <details>
74
+ <summary>Expand</summary>
75
+
76
+ | ID | Voice | Gender | Play Audio |
77
+ |-----|-----------------------|--------|--------------------------------------------------|
78
+ | 1 | nsnd-le-chuc | 👨 | <audio controls src="samples/nsnd-le-chuc.mp3"></audio> |
79
+ | 2 | speechify_10 | 👩 | <audio controls src="samples/speechify_10.wav"></audio> |
80
+ | 3 | atuan | 👨 | <audio controls src="samples/atuan.wav"></audio> |
81
+ | 4 | speechify_11 | 👩 | <audio controls src="samples/speechify_11.wav"></audio> |
82
+ | 5 | cdteam | 👨 | <audio controls src="samples/cdteam.wav"></audio> |
83
+ | 6 | speechify_12 | 👩 | <audio controls src="samples/speechify_12.wav"></audio> |
84
+ | 7 | cross_lingual_prompt | 👩 | <audio controls src="samples/cross_lingual_prompt.wav"></audio> |
85
+ | 8 | speechify_2 | 👩 | <audio controls src="samples/speechify_2.wav"></audio> |
86
+ | 9 | diep-chi | 👨 | <audio controls src="samples/diep-chi.wav"></audio> |
87
+ | 10 | speechify_3 | 👩 | <audio controls src="samples/speechify_3.wav"></audio> |
88
+ | 11 | doremon | 👨 | <audio controls src="samples/doremon.mp3"></audio> |
89
+ | 12 | speechify_4 | 👩 | <audio controls src="samples/speechify_4.wav"></audio> |
90
+ | 13 | jack-sparrow | 👨 | <audio controls src="samples/jack-sparrow.mp3"></audio> |
91
+ | 14 | speechify_5 | 👩 | <audio controls src="samples/speechify_5.wav"></audio> |
92
+ | 15 | nguyen-ngoc-ngan | 👩 | <audio controls src="samples/nguyen-ngoc-ngan.wav"></audio> |
93
+ | 16 | speechify_6 | 👩 | <audio controls src="samples/speechify_6.wav"></audio> |
94
+ | 17 | nu-nhe-nhang | 👩 | <audio controls src="samples/nu-nhe-nhang.wav"></audio> |
95
+ | 18 | speechify_7 | 👩 | <audio controls src="samples/speechify_7.wav"></audio> |
96
+ | 19 | quynh | 👩 | <audio controls src="samples/quynh.wav"></audio> |
97
+ | 20 | speechify_8 | 👩 | <audio controls src="samples/speechify_8.wav"></audio> |
98
+ | 21 | speechify_9 | 👩 | <audio controls src="samples/speechify_9.wav"></audio> |
99
+ | 22 | son-tung-mtp | 👨 | <audio controls src="samples/son-tung-mtp.wav"></audio> |
100
+ | 23 | zero_shot_prompt | 👩 | <audio controls src="samples/zero_shot_prompt.wav"></audio> |
101
+ | 24 | speechify_1 | 👩 | <audio controls src="samples/speechify_1.wav"></audio> |
102
+
103
+ <div>
104
+ </div>
105
+ </details>
106
+
107
+ ### Command Line Interface (CLI)
108
+ The VietTTS Command Line Interface (CLI) allows you to quickly generate speech directly from the terminal. Here's how to use it:
109
+ ```bash
110
+ # Usage
111
+ viettts --help
112
+
113
+ # Start API Server
114
+ viettts server --host 0.0.0.0 --port 8298
115
+
116
+ # List all built-in voices
117
+ viettts show-voices
118
+
119
+ # Synthesize speech from text with built-in voices
120
+ viettts synthesis --text "Xin chào" --voice 0 --output test.wav
121
+
122
+ # Clone voice from a local audio file
123
+ viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav
124
+ ```
125
+
126
+ ### API Client
127
+ #### Python (OpenAI Client)
128
+ You need to set environment variables for the OpenAI Client:
129
+ ```bash
130
+ # Set base_url and API key as environment variables
131
+ export OPENAI_BASE_URL=http://localhost:8298
132
+ export OPENAI_API_KEY=viet-tts # not used in the current version
133
+ ```
134
+ To create speech from input text:
135
+ ```python
136
+ from pathlib import Path
137
+ from openai import OpenAI
138
+
139
+ client = OpenAI()
140
+
141
+ output_file_path = Path(__file__).parent / "speech.wav"
142
+
143
+ with client.audio.speech.with_streaming_response.create(
144
+ model='tts-1',
145
+ voice='cdteam',
146
+ input='Xin chào Việt Nam.',
147
+ speed=1.0,
148
+ response_format='wav'
149
+ ) as response:
150
+ response.stream_to_file(output_file_path)
151
+ ```
152
+
153
+ #### CURL
154
+ ```bash
155
+ # Get all built-in voices
156
+ curl --location http://0.0.0.0:8298/v1/voices
157
+
158
+ # OpenAI format (built-in voices)
159
+ curl http://localhost:8298/v1/audio/speech \
160
+   -H "Authorization: Bearer viet-tts" \
161
+   -H "Content-Type: application/json" \
162
+   -d '{
163
+     "model": "tts-1",
164
+     "input": "Xin chào Việt Nam.",
165
+     "voice": "son-tung-mtp"
166
+   }' \
167
+   --output speech.wav
168
+
169
+ # API with voice from local file
170
+ curl --location http://0.0.0.0:8298/v1/tts \
171
+ --form 'text="xin chào"' \
172
+ --form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \
173
+ --output speech.wav
174
+ ```
175
+
176
+ #### Node
177
+ ```js
178
+ import fs from "fs";
179
+ import path from "path";
180
+ import OpenAI from "openai";
181
+
182
+ const openai = new OpenAI();
183
+
184
+ const speechFile = path.resolve("./speech.wav");
185
+
186
+ async function main() {
187
+ const mp3 = await openai.audio.speech.create({
188
+ model: "tts-1",
189
+ voice: "1",
190
+ input: "Xin chào Việt Nam.",
191
+ });
192
+ console.log(speechFile);
193
+ const buffer = Buffer.from(await mp3.arrayBuffer());
194
+ await fs.promises.writeFile(speechFile, buffer);
195
+ }
196
+ main();
197
+ ```
198
+
199
+ ## 🙏 Acknowledgement
200
+ - 💡 Borrowed code from [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
201
+ - 🎙️ VAD model from [silero-vad](https://github.com/snakers4/silero-vad)
202
+ - 📝 Text normalization with [Vinorm](https://github.com/v-nhandt21/Vinorm)
203
+
204
+ ## 📜 License
205
+ The **VietTTS** source code is released under the **Apache 2.0 License**. Pre-trained models and audio samples are licensed under the **CC BY-NC License**, based on an in-the-wild dataset. We apologize for any inconvenience this may cause.
206
+
207
+ ## ⚠️ Disclaimer
208
+ The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
209
+
210
+ ## 💬 Contact
211
+ - Facebook: https://fb.com/sam.rngd
212
+ - GitHub: https://github.com/dangvansam
213
+ - Email: [email protected]
VietTTS/models/README_VN.md ADDED
@@ -0,0 +1,203 @@
1
+ <p align="center">
2
+ <img src="https://github.com/dangvansam/viet-tts/blob/main/assets/viet-tts-medium.png?raw=true" style="width: 200px">
3
+   <h1 align="center" style="color: white; font-weight: bold; font-family:roboto"><span style="color: white; font-weight: bold; font-family:roboto">VietTTS</span>: Công cụ chuyển văn bản thành giọng nói tiếng Việt mã nguồn mở</h1>
4
+ </p>
5
+ <p align="center">
6
+ <a href="https://github.com/dangvansam/viet-tts"><img src="https://img.shields.io/github/stars/dangvansam/viet-tts?style=social"></a>
7
+ <a href="LICENSE"><img src="https://img.shields.io/github/license/dangvansam/viet-asr"></a>
8
+ <a href="https://huggingface.co/dangvansam/viet-tts/blob/main/README.md"><img src="https://img.shields.io/badge/README-English-blue"></a>
9
+ </p>
10
+
11
+ **VietTTS** là một bộ công cụ mã nguồn mở cung cấp mô hình TTS tiếng Việt mạnh mẽ, cho phép tổng hợp giọng nói tự nhiên và tạo giọng nói mới. **VietTTS** hỗ trợ nghiên cứu và ứng dụng trong công nghệ giọng nói tiếng Việt.
12
+
13
+ ## ⭐ Tính năng nổi bật
14
+ - **TTS**: Tổng hợp giọng nói từ văn bản với bất kỳ giọng nào qua audio mẫu
15
+ - **OpenAI-API-compatible**: Tương thích với API Text to Speech OpenAI
16
+
17
+ ## 🛠️ Cài đặt
18
+ VietTTS có thể được cài đặt qua trình cài đặt Python (chỉ hỗ trợ Linux, Windows và macOS sẽ có trong tương lai) hoặc Docker.
19
+
20
+ ### Trình cài đặt Python (Python>=3.10)
21
+
22
+ ```bash
23
+ git clone https://github.com/dangvansam/viet-tts.git
24
+ cd viet-tts
25
+
26
+ # (Tùy chọn) Tạo môi trường Python với conda hoặc dùng virtualenv
27
+ conda create --name viettts python=3.10
28
+ conda activate viettts
29
+
30
+ # Cài đặt
31
+ pip install -e . && pip cache purge
32
+ ```
33
+
34
+ ### Docker
35
+ 1. Cài đặt [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), và [CUDA](https://developer.nvidia.com/cuda-downloads).
36
+
37
+ 2. Chạy các lệnh sau:
38
+ ```bash
39
+ git clone https://github.com/dangvansam/viet-tts.git
40
+ cd viet-tts
41
+
42
+ # Xây dựng hình ảnh docker
43
+ docker compose build
44
+
45
+ # Chạy bằng docker-compose - tạo server tại: http://localhost:8298
46
+ docker compose up -d
47
+
48
+ # Chạy bằng docker run - tạo server tại: http://localhost:8298
49
+ docker run -itd --gpus all -p 8298:8298 -v ./pretrained-models:/app/pretrained-models --name viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298
50
+ ```
51
+
52
+ ## 🚀 Sử dụng
53
+
54
+ ### Giọng nói tích hợp 🤠
55
+ Bạn có thể sử dụng các giọng nói có sẵn dưới đây để tổng hợp giọng nói.
56
+ <details>
57
+   <summary>Mở rộng</summary>
58
+
59
+ | ID  | Giọng                   | Giới tính | Phát âm thanh                                   |
60
+ |-----|--------------------------|-----------|-------------------------------------------------|
61
+ | 1   | nsnd-le-chuc             | 👨        | <audio controls src="samples/nsnd-le-chuc.mp3"></audio> |
62
+ | 2   | speechify_10             | 👩        | <audio controls src="samples/speechify_10.wav"></audio> |
63
+ | 3   | atuan                    | 👨        | <audio controls src="samples/atuan.wav"></audio>        |
64
+ | 4   | speechify_11             | 👩        | <audio controls src="samples/speechify_11.wav"></audio> |
65
+ | 5   | cdteam                   | 👨        | <audio controls src="samples/cdteam.wav"></audio>       |
66
+ | 6   | speechify_12             | 👩        | <audio controls src="samples/speechify_12.wav"></audio> |
67
+ | 7   | cross_lingual_prompt     | 👩        | <audio controls src="samples/cross_lingual_prompt.wav"></audio> |
68
+ | 8   | speechify_2              | 👩        | <audio controls src="samples/speechify_2.wav"></audio>   |
69
+ | 9   | diep-chi                 | 👨        | <audio controls src="samples/diep-chi.wav"></audio>      |
70
+ | 10  | speechify_3              | 👩        | <audio controls src="samples/speechify_3.wav"></audio>   |
71
+ | 11  | doremon                  | 👨        | <audio controls src="samples/doremon.mp3"></audio>       |
72
+ | 12  | speechify_4              | 👩        | <audio controls src="samples/speechify_4.wav"></audio>   |
73
+ | 13  | jack-sparrow             | 👨        | <audio controls src="samples/jack-sparrow.mp3"></audio> |
74
+ | 14  | speechify_5              | 👩        | <audio controls src="samples/speechify_5.wav"></audio>   |
75
+ | 15  | nguyen-ngoc-ngan         | 👩        | <audio controls src="samples/nguyen-ngoc-ngan.wav"></audio> |
76
+ | 16  | speechify_6              | 👩        | <audio controls src="samples/speechify_6.wav"></audio>   |
77
+ | 17  | nu-nhe-nhang             | 👩        | <audio controls src="samples/nu-nhe-nhang.wav"></audio> |
78
+ | 18  | speechify_7              | 👩        | <audio controls src="samples/speechify_7.wav"></audio>   |
79
+ | 19  | quynh                    | 👩        | <audio controls src="samples/quynh.wav"></audio>         |
80
+ | 20  | speechify_8              | 👩        | <audio controls src="samples/speechify_8.wav"></audio>   |
81
+ | 21  | speechify_9              | 👩        | <audio controls src="samples/speechify_9.wav"></audio>   |
82
+ | 22  | son-tung-mtp             | 👨        | <audio controls src="samples/son-tung-mtp.wav"></audio> |
83
+ | 23  | zero_shot_prompt         | 👩        | <audio controls src="samples/zero_shot_prompt.wav"></audio> |
84
+ | 24  | speechify_1              | 👩        | <audio controls src="samples/speechify_1.wav"></audio>   |
85
+
86
+   <div>
87
+
88
+   </div>
89
+
90
+ </details>
91
+
92
+ ### Thực thi với lệnh (CLI)
93
+
94
+ Giao diện dòng lệnh VietTTS cho phép bạn tạo giọng nói từ terminal. Cách sử dụng:
95
+
96
+ ```bash
97
+ # Hướng dẫn sử dụng
98
+ viettts --help
99
+
100
+ # Khởi động API Server
101
+ viettts server --host 0.0.0.0 --port 8298
102
+
103
+ # Xem tất cả các giọng nói có sẵn
104
+ viettts show-voices
105
+
106
+ # Tổng hợp giọng nói từ văn bản với giọng có sẵn
107
+ viettts synthesis --text "Xin chào" --voice 0 --output test.wav
108
+
109
+ # Sao chép giọng từ audio file bất kì
110
+ viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav
111
+ ```
112
+
113
+ ### API Client
114
+ #### Python (OpenAI Client)
115
+ Thiết lập biến môi trường cho OpenAI Client:
116
+
117
+ ```bash
118
+ # Thiết lập base_url và API key như biến môi trường
119
+ export OPENAI_BASE_URL=http://localhost:8298
120
+ export OPENAI_API_KEY=viet-tts # không dùng trong phiên bản hiện tại
121
+ ```
122
+
123
+ Để tạo giọng nói từ văn bản đầu vào:
124
+
125
+ ```python
126
+ from pathlib import Path
127
+ from openai import OpenAI
128
+
129
+
130
+
131
+ client = OpenAI()
132
+ output_file_path = Path(__file__).parent / "speech.wav"
133
+
134
+ with client.audio.speech.with_streaming_response.create(
135
+ model='tts-1',
136
+ voice='cdteam',
137
+ input='Xin chào Việt Nam.',
138
+ speed=1.0,
139
+ response_format='wav'
140
+ ) as response:
141
+ response.stream_to_file(output_file_path)
142
+ ```
143
+
144
+ #### CURL
145
+ ```bash
146
+ # Lấy danh sách giọng có sẵn
147
+ curl --location http://0.0.0.0:8298/v1/voices
148
+
149
+ # OpenAI API format
150
+ curl http://localhost:8298/v1/audio/speech \
151
+   -H "Authorization: Bearer viet-tts" \
152
+   -H "Content-Type: application/json" \
153
+   -d '{
154
+     "model": "tts-1",
155
+     "input": "Xin chào Việt Nam.",
156
+     "voice": "son-tung-mtp"
157
+   }' \
158
+   --output speech.wav
159
+
160
+ # API với giọng từ file local
161
+ curl --location http://0.0.0.0:8298/v1/tts \
162
+ --form 'text="xin chào"' \
163
+ --form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \
164
+ --output speech.wav
165
+ ```
166
+
+ #### Node
+ ```js
+ import fs from "fs";
+ import path from "path";
+ import OpenAI from "openai";
+
+ const openai = new OpenAI();
+ const speechFile = path.resolve("./speech.wav");
+
+ async function main() {
+   // Request synthesized audio and write it to ./speech.wav
+   const response = await openai.audio.speech.create({
+     model: "tts-1",
+     voice: "1",
+     input: "Xin chào Việt Nam.",
+   });
+   console.log(speechFile);
+   const buffer = Buffer.from(await response.arrayBuffer());
+   await fs.promises.writeFile(speechFile, buffer);
+ }
+ main();
+ ```
+
+ ## 🙏 Related code
+ - 💡 Uses code from [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
+ - 🎙️ VAD model from [silero-vad](https://github.com/snakers4/silero-vad)
+ - 📝 Text normalization with [Vinorm](https://github.com/v-nhandt21/Vinorm)
+
+ ## 📜 License
+ The **VietTTS** source code is released under the **Apache 2.0 License**. The pretrained models and training audio samples are released under the **CC BY-NC License**, as they are based on datasets collected from the internet. We apologize for any inconvenience this may cause.
+
+ ## ⚠️ Disclaimer
+ The content above is provided for academic purposes only and is intended to demonstrate technical capabilities. Some examples are taken from the internet. If any content infringes your rights, please contact us to have it removed.
+
+ ## 💬 Contact
+ - Facebook: https://fb.com/sam.rngd
+ - GitHub: https://github.com/dangvansam
+ - Email: [email protected]
VietTTS/models/config.yaml ADDED
@@ -0,0 +1,129 @@
+ __set_seed1: !apply:random.seed [1986]
+ __set_seed2: !apply:numpy.random.seed [1986]
+ __set_seed3: !apply:torch.manual_seed [1986]
+ __set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+ sample_rate: 22050
+ text_encoder_input_size: 512
+ llm_input_size: 1024
+ llm_output_size: 1024
+ spk_embed_dim: 192
+
+ llm: !new:VietTTS.llm.llm.TransformerLM
+     text_encoder_input_size: !ref <text_encoder_input_size>
+     llm_input_size: !ref <llm_input_size>
+     llm_output_size: !ref <llm_output_size>
+     text_token_size: 60515
+     speech_token_size: 4096
+     length_normalized_loss: True
+     lsm_weight: 0
+     spk_embed_dim: !ref <spk_embed_dim>
+     text_encoder: !new:VietTTS.transformer.encoder.ConformerEncoder
+         input_size: !ref <text_encoder_input_size>
+         output_size: 1024
+         attention_heads: 16
+         linear_units: 4096
+         num_blocks: 6
+         dropout_rate: 0.1
+         positional_dropout_rate: 0.1
+         attention_dropout_rate: 0.0
+         normalize_before: True
+         input_layer: 'linear'
+         pos_enc_layer_type: 'rel_pos_espnet'
+         selfattention_layer_type: 'rel_selfattn'
+         use_cnn_module: False
+         macaron_style: False
+         use_dynamic_chunk: False
+         use_dynamic_left_chunk: False
+         static_chunk_size: 1
+     llm: !new:VietTTS.transformer.encoder.TransformerEncoder
+         input_size: !ref <llm_input_size>
+         output_size: !ref <llm_output_size>
+         attention_heads: 16
+         linear_units: 4096
+         num_blocks: 14
+         dropout_rate: 0.1
+         positional_dropout_rate: 0.1
+         attention_dropout_rate: 0.0
+         input_layer: 'linear_legacy'
+         pos_enc_layer_type: 'rel_pos_espnet'
+         selfattention_layer_type: 'rel_selfattn'
+         static_chunk_size: 1
+     sampling: !name:VietTTS.utils.common.ras_sampling
+         top_p: 0.8
+         top_k: 25
+         win_size: 10
+         tau_r: 0.1
+
+ flow: !new:VietTTS.flow.flow.MaskedDiffWithXvec
+     input_size: 512
+     output_size: 80
+     spk_embed_dim: !ref <spk_embed_dim>
+     output_type: 'mel'
+     vocab_size: 4096
+     input_frame_rate: 25
+     only_mask_loss: True
+     encoder: !new:VietTTS.transformer.encoder.ConformerEncoder
+         output_size: 512
+         attention_heads: 8
+         linear_units: 2048
+         num_blocks: 6
+         dropout_rate: 0.1
+         positional_dropout_rate: 0.1
+         attention_dropout_rate: 0.1
+         normalize_before: True
+         input_layer: 'linear'
+         pos_enc_layer_type: 'rel_pos_espnet'
+         selfattention_layer_type: 'rel_selfattn'
+         input_size: 512
+         use_cnn_module: False
+         macaron_style: False
+     length_regulator: !new:VietTTS.flow.length_regulator.InterpolateRegulator
+         channels: 80
+         sampling_ratios: [1, 1, 1, 1]
+     decoder: !new:VietTTS.flow.flow_matching.ConditionalCFM
+         in_channels: 240
+         n_spks: 1
+         spk_emb_dim: 80
+         cfm_params: !new:omegaconf.DictConfig
+             content:
+                 sigma_min: 1e-06
+                 solver: 'euler'
+                 t_scheduler: 'cosine'
+                 training_cfg_rate: 0.2
+                 inference_cfg_rate: 0.7
+                 reg_loss_type: 'l1'
+         estimator: !new:VietTTS.flow.decoder.ConditionalDecoder
+             in_channels: 320
+             out_channels: 80
+             channels: [256, 256]
+             dropout: 0.0
+             attention_head_dim: 64
+             n_blocks: 4
+             num_mid_blocks: 12
+             num_heads: 8
+             act_fn: 'gelu'
+
+ hift: !new:VietTTS.hifigan.generator.HiFTGenerator
+     in_channels: 80
+     base_channels: 512
+     nb_harmonics: 8
+     sampling_rate: !ref <sample_rate>
+     nsf_alpha: 0.1
+     nsf_sigma: 0.003
+     nsf_voiced_threshold: 10
+     upsample_rates: [8, 8]
+     upsample_kernel_sizes: [16, 16]
+     istft_params:
+         n_fft: 16
+         hop_len: 4
+     resblock_kernel_sizes: [3, 7, 11]
+     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+     source_resblock_kernel_sizes: [7, 11]
+     source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+     lrelu_slope: 0.1
+     audio_limit: 0.99
+     f0_predictor: !new:VietTTS.hifigan.f0_predictor.ConvRNNF0Predictor
+         num_class: 1
+         in_channels: 80
+         cond_channels: 512
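
The `!new:`, `!ref`, `!apply:`, and `!name:` tags above are hyperpyyaml syntax. The sketch below shows how a config like this is typically resolved into live model objects; it is an assumption-laden illustration (hyperpyyaml loader, plain `state_dict` checkpoints as in CosyVoice, and an importable `VietTTS` package so the `!new:` targets can be constructed), not the project's own loading code:

```python
# Minimal loading sketch under the assumptions stated above.
import torch
from hyperpyyaml import load_hyperpyyaml  # resolves !new:/!ref/!apply:/!name: tags

with open("VietTTS/models/config.yaml", "r") as f:
    configs = load_hyperpyyaml(f)  # instantiates llm, flow and hift as nn.Modules

llm, flow, hift = configs["llm"], configs["flow"], configs["hift"]
llm.load_state_dict(torch.load("VietTTS/models/llm.pt", map_location="cpu"))
flow.load_state_dict(torch.load("VietTTS/models/flow.pt", map_location="cpu"))
hift.load_state_dict(torch.load("VietTTS/models/hift.pt", map_location="cpu"))
llm.eval(); flow.eval(); hift.eval()
```
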
VietTTS/models/flow.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1411de192039a21d53f0bf1968feb50586ce71d81ea1443f8163f4d1c46c5455
+ size 419901370
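
The model and sample entries in this commit are Git LFS pointer files: the real binaries are fetched with `git lfs pull`, and each pointer records the blob's SHA-256 (`oid`) and byte `size`. Below is a small optional sketch, not part of VietTTS, for checking a downloaded checkpoint against its pointer, using the `flow.pt` values above:

```python
import hashlib
from pathlib import Path

def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large checkpoints don't need to fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

ckpt = Path("VietTTS/models/flow.pt")
expected_oid = "1411de192039a21d53f0bf1968feb50586ce71d81ea1443f8163f4d1c46c5455"  # from the pointer above
expected_size = 419901370  # bytes, from the pointer above

assert ckpt.stat().st_size == expected_size, "size mismatch - file may still be an LFS pointer"
assert sha256_of(ckpt) == expected_oid, "checksum mismatch"
print("flow.pt verified")
```
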
VietTTS/models/hift.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91e679b6ca1eff71187ffb4f3ab0444935594cdcc20a9bd12afad111ef8d6012
+ size 81896716
VietTTS/models/llm.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1773e5afe16a88ee82e33cf510a07717ce1346d2e74856733d72dc297a9a017
+ size 1260740644
VietTTS/models/speech_embedding.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
+ size 28303423
VietTTS/models/speech_tokenizer.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486
+ size 522625011
VietTTS/samples/cdteam.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6adf2c56a4dabcbfcc427df36f9dd268efb1153881b682071a80ad80ae4f0ac5
+ size 1290116
VietTTS/samples/cross_lingual_prompt.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:353a7715c2e4811f4045658b29d1ce67ecad5120e09de10ce890f1763aab486c
+ size 606404
VietTTS/samples/diep-chi.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af5ae833e3d2213c09704d83535d5416744c8372368edc1005a2587c631c87ea
+ size 1272260
VietTTS/samples/doremon.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c738f39e46db361e2e02f9064ef736c75b7dcf145873682619c123259b04762
+ size 761386
VietTTS/samples/jack-sparrow.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8729fd534af6f354c39bdb90cfa352654f876eaa0ad5759bf617797c9388878c
+ size 177121
VietTTS/samples/nguyen-ngoc-ngan.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b9ca5b01b44fdd2be7416644fbc7d463248405554fff24cf1eaaed93bd31cea
+ size 5351668
VietTTS/samples/nsnd-le-chuc.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08d620e721295afdba2cf3d9d4e772f10cd5b416ef14c8d11284431657deeb97
+ size 1416881
VietTTS/samples/nu-nhe-nhang.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a120871b168489a33b7f3188764b0f973583bf5284bd96cd805d9e6256a7e45
+ size 710734
VietTTS/samples/quynh.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7c5ff2187ca7a5e1371e4ab48cffdadbc38684da8ac3bcae598122ef294401f
+ size 2178450
VietTTS/samples/son-tung-mtp.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b5b22e5beb4e71b5405f7839656c4d5d71fc34f03b65a58ab27eb86a7f3dfe52
+ size 1473048