Upload 9 files
- .gitattributes +35 -35
- README.md +62 -0
- config.json +82 -0
- model.safetensors +3 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +4 -0
- tokenizer_config.json +12 -0
- trainer.py +270 -0
- vocab.json +86 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,62 @@
---
license: cc-by-nc-4.0
tags:
- tts
- gpt2
- vae
pipeline_tag: text-to-speech
---

# Malayalam Text-to-Speech

This repository contains the **Malayalam (mal)** language text-to-speech (TTS) model checkpoint.

## Model Details

Sura (**S**tochastic **U**nified **R**epresentation for **A**dversarial learning) is an advanced speech synthesis model that generates speech waveforms conditioned on input text sequences. It is based on a conditional variational autoencoder (VAE) architecture, consisting of a posterior encoder, a decoder, and a conditional prior.

The model's text encoder is built on GPT-2, while the decoder is a VAE with 124M parameters. A flow-based module, composed of the GPT-2-based encoder and cascaded dense layers, predicts spectrogram-based acoustic features. The spectrogram is then transformed into a speech waveform using a stack of transposed convolutional layers. To capture the one-to-many nature of TTS, where the same text can be spoken in multiple ways, the model also includes a stochastic duration predictor, allowing for varied speech rhythms from the same text input.

Sura is trained end-to-end using a combination of losses derived from the variational lower bound and adversarial training. During inference, the text encodings are up-sampled according to the predicted durations and then mapped into a waveform by the flow module and the VAE decoder. Because the duration predictor is stochastic, the model is non-deterministic and requires a fixed seed to produce identical speech outputs.

## Usage

```
pip install --upgrade transformers accelerate
```

Then, run inference with the following code snippet:

```python
from transformers import VitsModel, AutoTokenizer
import torch

model = VitsModel.from_pretrained("aoxo/gpt2-vae-tts-mal")
tokenizer = AutoTokenizer.from_pretrained("aoxo/gpt2-vae-tts-mal")

text = "കള്ളാ കടയാടി മോനെ"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    output = model(**inputs).waveform
```

The resulting waveform can be saved as a `.wav` file:

```python
import scipy

scipy.io.wavfile.write("kadayadi_mone.wav", rate=model.config.sampling_rate, data=output.squeeze().numpy())
```

Or displayed in a Jupyter Notebook / Google Colab:

```python
from IPython.display import Audio

Audio(output.numpy(), rate=model.config.sampling_rate)
```

## License

The model is licensed as **CC-BY-NC 4.0**.
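Note: since the duration predictor is stochastic, reproducing an identical waveform requires fixing the RNG seed before each generation. A minimal sketch, reusing the repo id from the README and assuming `torch.manual_seed` as the seeding mechanism:

```python
import torch
from transformers import VitsModel, AutoTokenizer

model = VitsModel.from_pretrained("aoxo/gpt2-vae-tts-mal")
tokenizer = AutoTokenizer.from_pretrained("aoxo/gpt2-vae-tts-mal")
inputs = tokenizer("കള്ളാ കടയാടി മോനെ", return_tensors="pt")

torch.manual_seed(42)  # fix the seed so the stochastic duration predictor is repeatable
with torch.no_grad():
    waveform_a = model(**inputs).waveform

torch.manual_seed(42)  # same seed, same predicted durations, same waveform
with torch.no_grad():
    waveform_b = model(**inputs).waveform

print(torch.allclose(waveform_a, waveform_b))  # expected: True
```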
config.json
ADDED
@@ -0,0 +1,82 @@
{
  "activation_dropout": 0.1,
  "architectures": [
    "VitsModel"
  ],
  "attention_dropout": 0.1,
  "depth_separable_channels": 2,
  "depth_separable_num_layers": 3,
  "duration_predictor_dropout": 0.5,
  "duration_predictor_filter_channels": 256,
  "duration_predictor_flow_bins": 10,
  "duration_predictor_kernel_size": 3,
  "duration_predictor_num_flows": 4,
  "duration_predictor_tail_bound": 5.0,
  "ffn_dim": 768,
  "ffn_kernel_size": 3,
  "flow_size": 192,
  "hidden_act": "relu",
  "hidden_dropout": 0.1,
  "hidden_size": 192,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "leaky_relu_slope": 0.1,
  "model_type": "vits",
  "noise_scale": 0.667,
  "noise_scale_duration": 0.8,
  "num_attention_heads": 2,
  "num_hidden_layers": 6,
  "num_speakers": 1,
  "posterior_encoder_num_wavenet_layers": 16,
  "prior_encoder_num_flows": 4,
  "prior_encoder_num_wavenet_layers": 4,
  "resblock_dilation_sizes": [
    [1, 3, 5],
    [1, 3, 5],
    [1, 3, 5]
  ],
  "resblock_kernel_sizes": [3, 7, 11],
  "sampling_rate": 16000,
  "speaker_embedding_size": 0,
  "speaking_rate": 1.0,
  "spectrogram_bins": 513,
  "torch_dtype": "float32",
  "transformers_version": "4.33.0.dev0",
  "upsample_initial_channel": 512,
  "upsample_kernel_sizes": [16, 16, 4, 4],
  "upsample_rates": [8, 8, 2, 2],
  "use_bias": true,
  "use_stochastic_duration_prediction": true,
  "vocab_size": 84,
  "wavenet_dilation_rate": 1,
  "wavenet_dropout": 0.0,
  "wavenet_kernel_size": 5,
  "window_size": 4
}
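The config above follows the standard `transformers` VITS schema, so it can be loaded and inspected directly. A minimal sketch, assuming the `aoxo/gpt2-vae-tts-mal` repo id used in the README:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("aoxo/gpt2-vae-tts-mal")
print(config.model_type)                          # "vits"
print(config.sampling_rate)                       # 16000 Hz output audio
print(config.vocab_size)                          # 84, matching vocab.json
print(config.use_stochastic_duration_prediction)  # True
```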
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a97a1e677ec67e05124b799dadd66630181fe9c29beb4e590454689ff8f698c5
size 145262840
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:046f0de3c235f75b96b478bab98bebcfe8468a11add65a8cde3f88f2cb2c63b9
size 145424050
special_tokens_map.json
ADDED
@@ -0,0 +1,4 @@
{
  "pad_token": "—",
  "unk_token": "<unk>"
}
tokenizer_config.json
ADDED
@@ -0,0 +1,12 @@
{
  "add_blank": true,
  "clean_up_tokenization_spaces": true,
  "is_uroman": false,
  "language": "mal",
  "model_max_length": 1000000000000000019884624838656,
  "normalize": true,
  "pad_token": "—",
  "phonemize": false,
  "tokenizer_class": "VitsTokenizer",
  "unk_token": "<unk>"
}
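With `add_blank` set to true, the `VitsTokenizer` is expected to intersperse the blank/pad id (0, mapped to "—" in vocab.json) between the character ids. A minimal sketch to inspect this, assuming the tokenizer files above live under the same repo id as the README:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("aoxo/gpt2-vae-tts-mal")

ids = tokenizer("കള്ളാ", return_tensors="pt").input_ids
# Expect roughly 2 * num_characters + 1 ids, with the blank id 0 interleaved
print(ids)
```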
trainer.py
ADDED
@@ -0,0 +1,270 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2Model
from torchaudio.transforms import MelSpectrogram, InverseMelScale, GriffinLim
import torchaudio
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import GradScaler, autocast

class TextToSpeechDataset(Dataset):
    def __init__(self, text_files, audio_files, tokenizer, mel_transform, max_length=512):
        self.text_files = text_files
        self.audio_files = audio_files
        self.tokenizer = tokenizer
        self.mel_transform = mel_transform
        self.max_length = max_length

    def __len__(self):
        return len(self.text_files)

    def __getitem__(self, idx):
        # Load text
        with open(self.text_files[idx], 'r') as f:
            text = f.read().strip()

        # Tokenize text
        text_tokens = self.tokenizer.encode(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        ).squeeze(0)

        # Load audio and convert to mel spectrogram
        waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        mel_spec = self.mel_transform(waveform)

        return text_tokens, mel_spec.squeeze(0)

def collate_fn(batch):
    text_tokens, mel_specs = zip(*batch)

    # Pad text tokens to the longest sequence in the batch
    max_text_len = max(tokens.size(0) for tokens in text_tokens)
    text_tokens_padded = torch.stack([
        torch.cat([tokens, torch.zeros(max_text_len - tokens.size(0), dtype=tokens.dtype)], dim=0)
        if tokens.size(0) < max_text_len
        else tokens[:max_text_len]
        for tokens in text_tokens
    ])

    # Pad mel spectrograms to the longest one in the batch
    max_mel_len = max(spec.size(1) for spec in mel_specs)
    mel_specs_padded = torch.stack([
        F.pad(spec, (0, max_mel_len - spec.size(1)))
        if spec.size(1) < max_mel_len
        else spec[:, :max_mel_len]
        for spec in mel_specs
    ])

    return text_tokens_padded, mel_specs_padded

class VAEDecoder(nn.Module):
    def __init__(self, latent_dim, mel_channels=80, mel_frames=80):
        super().__init__()
        # Probabilistic heads: project the latent into mean and log-variance
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)

        # Decoder: map the sampled latent to a fixed-size mel spectrogram
        self.decoder_layers = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, mel_channels * mel_frames),  # Output mel spectrogram
            nn.Unflatten(1, (mel_channels, mel_frames))
        )

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, z):
        mu = self.fc_mu(z)
        log_var = self.fc_var(z)

        # Reparameterization trick
        z = self.reparameterize(mu, log_var)

        # Decode
        mel_spec = self.decoder_layers(z)

        return mel_spec, mu, log_var

class TextToSpeechModel(nn.Module):
    def __init__(self, text_encoder, vae_decoder, latent_dim=256):
        super().__init__()
        self.text_encoder = text_encoder
        self.vae_decoder = vae_decoder

        # Projection layer to map encoder output to latent space
        self.projection = nn.Linear(text_encoder.config.hidden_size, latent_dim)

    def forward(self, text_tokens):
        # Encode text
        encoder_output = self.text_encoder(text_tokens).last_hidden_state

        # Mean pooling of encoder output
        text_embedding = encoder_output.mean(dim=1)

        # Project to latent space
        latent_z = self.projection(text_embedding)

        # Decode to mel spectrogram
        mel_spec, mu, log_var = self.vae_decoder(latent_z)

        return mel_spec, mu, log_var

def vae_loss(reconstruction, target, mu, log_var):
    # The decoder emits a fixed number of mel frames, so crop or pad the target
    # along the time axis to match before computing the reconstruction loss
    num_frames = reconstruction.size(-1)
    if target.size(-1) >= num_frames:
        target = target[..., :num_frames]
    else:
        target = F.pad(target, (0, num_frames - target.size(-1)))

    # Reconstruction loss (MSE)
    recon_loss = F.mse_loss(reconstruction, target, reduction='mean')

    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return recon_loss + 0.001 * kl_loss

def train_model(num_epochs=10, accumulation_steps=16):
    # Tokenizer and mel spectrogram transform
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    # Mel spectrogram configuration
    mel_transform = MelSpectrogram(
        sample_rate=16000,
        n_mels=80,
        n_fft=1024,
        hop_length=256
    )

    # Data preparation
    text_folder = './texts'
    audio_folder = './audio'

    # Load text and audio files (sorted so each transcript pairs with its audio file)
    text_files = sorted(os.path.join(text_folder, f) for f in os.listdir(text_folder) if f.endswith('.txt'))
    audio_files = sorted(os.path.join(audio_folder, f) for f in os.listdir(audio_folder) if f.endswith('.wav'))

    # Split dataset
    train_texts, val_texts, train_audios, val_audios = train_test_split(
        text_files, audio_files, test_size=0.1, random_state=42
    )

    # Create datasets and dataloaders
    train_dataset = TextToSpeechDataset(train_texts, train_audios, tokenizer, mel_transform)
    val_dataset = TextToSpeechDataset(val_texts, val_audios, tokenizer, mel_transform)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

    # Model components
    text_encoder = GPT2Model.from_pretrained('gpt2')
    vae_decoder = VAEDecoder(latent_dim=256)

    # Combine into full model
    model = TextToSpeechModel(text_encoder, vae_decoder)

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # Gradient scaler for mixed-precision training
    scaler = GradScaler()

    best_val_loss = float('inf')

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0

        for batch_idx, (text_tokens, mel_specs) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            text_tokens = text_tokens.to(device)
            mel_specs = mel_specs.to(device)

            with autocast(device_type='cuda', dtype=torch.float16):
                # Forward pass
                reconstructed_mel, mu, log_var = model(text_tokens)

                # Compute loss
                loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)

            # Scale the loss for gradient accumulation and backpropagate
            loss = loss / accumulation_steps
            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for text_tokens, mel_specs in tqdm(val_loader, desc=f"Validation {epoch+1}"):
                text_tokens = text_tokens.to(device)
                mel_specs = mel_specs.to(device)

                reconstructed_mel, mu, log_var = model(text_tokens)
                loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)
                val_loss += loss.item()

        # Scheduler step
        scheduler.step()

        # Print epoch summary
        print(f'Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader)}, Val Loss: {val_loss/len(val_loader)}')

        # Save the best model so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_tts_model.pth')

    return model

# Optional: Inference function for generating mel spectrograms
def generate_mel_spectrogram(text, model, tokenizer, device):
    model.eval()
    with torch.no_grad():
        # Tokenize input text
        text_tokens = tokenizer.encode(
            text,
            return_tensors="pt",
            truncation=True,
            padding='max_length',
            max_length=512
        ).to(device)

        # Generate mel spectrogram
        mel_spec, _, _ = model(text_tokens)

    return mel_spec

# Optional: Convert mel spectrogram back to audio
def mel_to_audio(mel_spec, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=80):
    # Invert the mel scale back to a linear spectrogram, then use Griffin-Lim
    # to estimate the phase and recover a waveform
    inverse_mel = InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate)
    griffin_lim = GriffinLim(n_fft=n_fft, hop_length=hop_length)

    # Convert mel spectrogram back to waveform
    waveform = griffin_lim(inverse_mel(mel_spec))

    return waveform

# Run training only when executed directly, so the module stays importable
if __name__ == '__main__':
    trained_model = train_model()
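A minimal usage sketch for the optional helpers above, assuming a checkpoint written to `best_tts_model.pth` by `train_model` and that `trainer.py` is importable (the training call is guarded by `__main__`, so importing does not retrain):

```python
import torch
import torchaudio
from transformers import GPT2Model, GPT2Tokenizer

from trainer import TextToSpeechModel, VAEDecoder, generate_mel_spectrogram, mel_to_audio

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Rebuild the model skeleton and load the trained weights (hypothetical checkpoint path)
model = TextToSpeechModel(GPT2Model.from_pretrained('gpt2'), VAEDecoder(latent_dim=256))
model.load_state_dict(torch.load('best_tts_model.pth', map_location='cpu'))

device = torch.device('cpu')
mel = generate_mel_spectrogram("കള്ളാ കടയാടി മോനെ", model, tokenizer, device)

# Rough waveform via Griffin-Lim; quality is limited compared to a neural vocoder
waveform = mel_to_audio(mel.squeeze(0))
torchaudio.save('sample.wav', waveform.unsqueeze(0), 16000)
```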
vocab.json
ADDED
@@ -0,0 +1,86 @@
{
  " ": 61,
  "'": 81,
  "-": 33,
  "0": 6,
  "1": 57,
  "2": 54,
  "3": 66,
  "4": 23,
  "5": 79,
  "6": 46,
  "_": 43,
  "c": 75,
  "i": 18,
  "m": 30,
  "o": 73,
  "q": 82,
  "ം": 29,
  "ഃ": 38,
  "അ": 63,
  "ആ": 26,
  "ഇ": 10,
  "ഈ": 67,
  "ഉ": 19,
  "ഊ": 13,
  "എ": 47,
  "ഏ": 45,
  "ഐ": 80,
  "ഒ": 5,
  "ഓ": 51,
  "ഔ": 68,
  "ക": 56,
  "ഖ": 8,
  "ഗ": 39,
  "ഘ": 50,
  "ങ": 59,
  "ച": 14,
  "ഛ": 35,
  "ജ": 31,
  "ഞ": 9,
  "ട": 37,
  "ഠ": 22,
  "ഡ": 11,
  "ഢ": 62,
  "ണ": 40,
  "ത": 1,
  "ഥ": 49,
  "ദ": 78,
  "ധ": 32,
  "ന": 71,
  "പ": 69,
  "ഫ": 20,
  "ബ": 65,
  "ഭ": 60,
  "മ": 64,
  "യ": 16,
  "ര": 2,
  "റ": 44,
  "ല": 21,
  "ള": 3,
  "ഴ": 4,
  "വ": 76,
  "ശ": 83,
  "ഷ": 17,
  "സ": 27,
  "ഹ": 34,
  "ാ": 15,
  "ി": 12,
  "ീ": 52,
  "ു": 28,
  "ൂ": 7,
  "ൃ": 74,
  "െ": 24,
  "േ": 55,
  "ൈ": 53,
  "ൊ": 72,
  "ോ": 42,
  "്": 36,
  "ൗ": 25,
  "ൺ": 70,
  "ൻ": 77,
  "ർ": 48,
  "ൽ": 41,
  "ൾ": 58,
  "—": 0
}