jaskaran Singh committed
Commit · 390d94d
Parent(s): d0a22d2

indic
Browse files
- .gitattributes +3 -0
- README.md +98 -18
- maha_tts/__init__.py +3 -1
- maha_tts/config.py +3 -2
- maha_tts/inference.py +74 -24
- maha_tts/models/__init__.py +0 -0
- maha_tts/models/autoregressive.py +0 -135
- maha_tts/models/diff_model.py +0 -303
- maha_tts/models/modules.py +0 -406
- maha_tts/models/vocoder.py +0 -342
- maha_tts/pretrained_models/.DS_Store +0 -0
- maha_tts/pretrained_models/Smolie-en/.DS_Store +0 -0
- maha_tts/pretrained_models/Smolie-en/s2a_latest.pt +3 -0
- maha_tts/pretrained_models/{smolie/T2S → Smolie-en}/t2s_best.pt +2 -2
- maha_tts/pretrained_models/Smolie-in/.DS_Store +0 -0
- maha_tts/pretrained_models/Smolie-in/s2a_latest.pt +3 -0
- maha_tts/pretrained_models/{smolie/S2A/s2a_latest.pt → Smolie-in/t2s_best.pt} +2 -2
- maha_tts/text/cleaners.py +2 -2
- maha_tts/text/symbols.py +11 -1
- maha_tts/utils/audio.py +2 -2
- ref_clips/2971_4275_000003_000007.wav +0 -0
- ref_clips/2971_4275_000020_000001.wav +0 -0
- ref_clips/2971_4275_000023_000010.wav +0 -0
- ref_clips/2971_4275_000049_000000.wav +0 -0
- ref_clips/2971_4275_000049_000004.wav +0 -0
- ref_clips/2971_4275_000050_000000.wav +0 -0
- requirements.txt +46 -0
- setup.py +23 -0
- tts.py +6 -4
.gitattributes CHANGED

@@ -37,3 +37,6 @@ maha_tts/pretrained_models/smolie/T2S/t2s_best.pt filter=lfs diff=lfs merge=lfs -text
 maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt filter=lfs diff=lfs merge=lfs -text
 maha_tts/pretrained_models/hifigan/config.json filter=lfs diff=lfs merge=lfs -text
 maha_tts/pretrained_models/hifigan/g_02500000 filter=lfs diff=lfs merge=lfs -text
+maha_tts/pretrained_models/hifigan filter=lfs diff=lfs merge=lfs -text
+maha_tts/pretrained_models/Smolie-en filter=lfs diff=lfs merge=lfs -text
+maha_tts/pretrained_models/Smolie-in filter=lfs diff=lfs merge=lfs -text
README.md CHANGED

@@ -1,37 +1,37 @@
 <div align="center">
 
 <h1>MahaTTS: An Open-Source Large Speech Generation Model in the making</h1>
-a Dubverse Black initiative <br> <br>
+a <a href = "https://black.dubverse.ai">Dubverse Black</a> initiative <br> <br>
 
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-eOQqznKWwAfMdusJ_LDtDhjIyAlSMrG?usp=sharing)
 [![Discord Shield](https://discordapp.com/api/guilds/1162007551987171410/widget.png?style=shield)](https://discord.gg/4VGnrgpBN)
-
 </div>
 
 ------
 
 ## Description
-MahaTTS (Maha means 'Great' in sanskrit), is a speech generation model which is inspired from tortoise-tts, except it uses seamless M4t wav2vec2 to extract semantic tokens.
-Since seamless M4t wav2vec2 is trained on multilingual data, it makes this model easier to scale on multilingual data.
 
-
-
-
-
-
-
+MahaTTS, with Maha signifying 'Great' in Sanskrit, is a Text to Speech Model developed by [Dubverse.ai](https://dubverse.ai). We drew inspiration from the [tortoise-tts](https://github.com/neonbjb/tortoise-tts) model, but our model uniquely utilizes seamless M4t wav2vec2 for semantic token extraction. As this specific variant of wav2vec2 is trained on multilingual data, it enhances our model's scalability across different languages.
+
+We are providing access to pretrained model checkpoints, which are ready for inference and available for commercial use.
+
+<img width="993" alt="MahaTTS Architecture" src="https://github.com/dubverse-ai/MahaTTS/assets/32906806/7429d3b6-3f19-4bd8-9005-ff9e16a698f8">
+
+## Updates
+
+**2023-11-13**
+
+- MahaTTS Released! Open sourced Smolie
+- Community and access to new features on our [Discord](https://discord.gg/uFPrzBqyF2)
 
 ## Features
-
+
+1. Multilinguality (coming soon)
 2. Realistic Prosody and intonation
 3. Multi-voice capabilities
 
-## Current Progress
-Trained on 200 hours of LibriTTS model -> 'Smolie'
-
 ## Installation
+
 ```bash
 pip install git+https://github.com/dubverse-ai/MahaTTS.git
 ```
@@ -39,8 +39,88 @@ pip install git+https://github.com/dubverse-ai/MahaTTS.git
 ```bash
 pip install maha-tts
 ```
+
+## api usage
+
+```bash
+!gdown --folder 1-HEc3V4f6X93I8_IfqExLfL3s8I_dXGZ -q # download speakers ref files
+
+import torch,glob
+from maha_tts import load_models,infer_tts
+from scipy.io.wavfile import write
+from IPython.display import Audio,display
+
+# PATH TO THE SPEAKERS WAV FILES
+speaker =['/content/infer_ref_wavs/2272_152282_000019_000001/',
+          '/content/infer_ref_wavs/2971_4275_000049_000000/',
+          '/content/infer_ref_wavs/4807_26852_000062_000000/',
+          '/content/infer_ref_wavs/6518_66470_000014_000002/']
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+diff_model,ts_model,vocoder,diffuser = load_models('Smolie',device)
+print('Using:',device)
+
+speaker_num = 0 # @param ["0", "1", "2", "3"] {type:"raw"}
+text = "I freakin love how Elon came to life the moment they started talking about gaming and specifically diablo, you can tell that he didn't want that part of the discussion to end, while Lex to move on to the next subject! Once a true gamer, always a true gamer!" # @param {type:"string"}
+
+ref_clips = glob.glob(speaker[speaker_num]+'*.wav')
+audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder)
+
+write('/content/test.wav',sr,audio)
+```
 ## Roadmap
-- [x] Smolie - eng
-- [ ] Smolie - indic
-- [ ] Optimizations for inference
+- [x] Smolie - eng (trained on 200 hours of LibriTTS)
+- [ ] Smolie - indic (Train on Indian languages, coming soon)
+- [ ] Optimizations for inference (looking for contributors, check issues)
+
+## Some Generated Samples
+0 -> "I seriously laughed so much hahahaha (seals with headphones...) and appreciate both the interviewer and the subject. Major respect for two extraordinary humans - and in this time of gratefulness, I'm thankful for you both and this forum!"
+
+1 -> "I freakin love how Elon came to life the moment they started talking about gaming and specifically diablo, you can tell that he didn't want that part of the discussion to end, while Lex to move on to the next subject! Once a true gamer, always a true gamer!"
+
+2 -> "hello there! how are you?" (This one didn't work well, M1 model hallucinated)
+
+3 -> "Who doesn't love a good scary story, something to send a chill across your skin in the middle of summer's heat or really, any other time? And this year, we're celebrating the two hundredth birthday of one of the most famous scary stories of all time: Frankenstein."
+
+https://github.com/dubverse-ai/MahaTTS/assets/32906806/462ee134-5d8c-43c8-a425-3b6cabd2ff85
+
+https://github.com/dubverse-ai/MahaTTS/assets/32906806/40c62402-7f65-4a35-b739-d8b8a082ad62
+
+https://github.com/dubverse-ai/MahaTTS/assets/32906806/f0a9628c-ef81-450d-ab82-2f4c4626864e
+
+https://github.com/dubverse-ai/MahaTTS/assets/32906806/15476151-72ea-410d-bcdc-177433df7884
+
+## Technical Details
+
+### Model Params
+| Model (Smolie)            | Parameters | Model Type | Output            |
+|:-------------------------:|:----------:|------------|:-----------------:|
+| Text to Semantic (M1)     | 69 M       | Causal LM  | 10,001 Tokens     |
+| Semantic to MelSpec (M2)  | 108 M      | Diffusion  | 2x 80x Melspec    |
+| Hifi Gan Vocoder          | 13 M       | GAN        | Audio Waveform    |
+
+### Languages Supported
+| Language | Status |
+| --- | :---: |
+| English (en) | ✅ |
+
+## License
+
+MahaTTS is licensed under the Apache 2.0 License.
+
+## 🙏 Appreciation
 
+- [tortoise-tts](https://github.com/neonbjb/tortoise-tts)
+- [M4t Seamless](https://github.com/facebookresearch/seamless_communication), [AudioLM](https://arxiv.org/abs/2209.03143) and many other ground-breaking papers that enabled the development of MahaTTS
+- [Diffusion training](https://github.com/openai/guided-diffusion) for training diffusion model
+- [Huggingface](https://huggingface.co/docs/transformers/index) for related training and inference code
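The commit message ("indic") and the new Smolie-en / Smolie-in entries point at a multilingual call path. Below is a rough sketch of what inference might look like once the Smolie-in checkpoint URLs (currently empty strings in `model_dirs`) are published; it only uses the signatures visible in this diff, the `language` value is assumed to be a name from `config.langs`, and the diff does not show how the model consumes it.

```python
# Hypothetical usage of the multilingual surface added in this commit (not from the repo docs).
import glob, torch
from maha_tts import load_models, infer_tts, config
from scipy.io.wavfile import write

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
diff_model, ts_model, vocoder, diffuser = load_models('Smolie-in', device)  # new indic checkpoint name

print(config.langs)                        # languages registered in config.py by this commit
ref_clips = glob.glob('ref_clips/*.wav')   # speaker reference wavs shipped in this repo
audio, sr = infer_tts("Hello, how are you?", ref_clips, diffuser,
                      diff_model, ts_model, vocoder, language='english')
write('smolie_in_test.wav', sr, audio)
```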
maha_tts/__init__.py CHANGED

@@ -1 +1,3 @@
-from .inference import load_models,load_diffuser,infer_tts
+from maha_tts.inference import load_models,load_diffuser,infer_tts
+from maha_tts.config import config
+__version__ = '1.0.0'
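A quick sanity check of the package surface after this change (a sketch; it assumes the package is installed and importable as `maha_tts`):

```python
import maha_tts

print(maha_tts.__version__)            # '1.0.0', set in this commit
print(callable(maha_tts.load_models))  # True
print(callable(maha_tts.infer_tts))    # True
```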
maha_tts/config.py CHANGED

@@ -5,8 +5,9 @@ class config:
     seed_value = 3407
 
     # Text to Semantic
-    t2s_position =
-
+    t2s_position = 4096
+    langs = ['english','tamil', 'telugu', 'punjabi', 'marathi', 'hindi', 'gujarati', 'bengali', 'assamese']
+    lang_index = {i:j for j,i in enumerate(langs)}
     # Semantic to acoustic
     sa_timesteps_max = 1000
 
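For reference, `lang_index` above is just the name-to-position map over `langs`:

```python
langs = ['english', 'tamil', 'telugu', 'punjabi', 'marathi', 'hindi', 'gujarati', 'bengali', 'assamese']
lang_index = {i: j for j, i in enumerate(langs)}   # same comprehension as in config.py
print(lang_index['english'], lang_index['hindi'])  # 0 5
```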
maha_tts/inference.py CHANGED

@@ -1,7 +1,8 @@
-import torch,glob,os
+import torch,glob,os,requests
 import numpy as np
 import torch.nn.functional as F
 
+from tqdm import tqdm
 from librosa.filters import mel as librosa_mel_fn
 from scipy.io.wavfile import write
 from scipy.special import softmax
@@ -11,10 +12,12 @@ from maha_tts.models.vocoder import load_vocoder_model,infer_wav
 from maha_tts.utils.audio import denormalize_tacotron_mel,normalize_tacotron_mel,load_wav_to_torch,dynamic_range_compression
 from maha_tts.utils.stft import STFT
 from maha_tts.utils.diffusion import SpacedDiffusion,get_named_beta_schedule,space_timesteps
-from maha_tts.text.symbols import labels,text_labels,code_labels,text_enc,text_dec,code_enc,code_dec
+from maha_tts.text.symbols import labels,text_labels,text_labels_en,code_labels,text_enc,text_dec,code_enc,code_dec,text_enc_en,text_dec_en
 from maha_tts.text.cleaners import english_cleaners
 from maha_tts.config import config
 
+DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'maha_tts', 'models')
+DEFAULT_MODELS_DIR = '/Users/jaskaransingh/Desktop/MahaTTS/models/'
 stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)
 
 mel_basis = librosa_mel_fn(
@@ -23,13 +26,52 @@ mel_basis = librosa_mel_fn(
 mel_basis = torch.from_numpy(mel_basis).float()
 
 model_dirs= {
-'Smolie':'
-
+    'Smolie':['https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt',
+              'https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/smolie/T2S/t2s_best.pt'],
+    'Smolie-en':[''],
+    'Smolie-in':[''],
+    'hifigan':['https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/hifigan/g_02500000',
+               'https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/hifigan/config.json']
 }
 
-def
-
+def download_file(url, filename):
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get('content-length', 0))
+
+    # Check if the response was successful (status code 200)
+    response.raise_for_status()
+
+    with open(filename, 'wb') as file, tqdm(
+        desc=filename,
+        total=total_size,
+        unit='B',
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in response.iter_content(chunk_size=1024):
+            # Write data to the file
+            file.write(data)
+            # Update the progress bar
+            bar.update(len(data))
+
+    print(f"Download complete: {filename}\n")
+
+def download_model(name):
+    print('Downloading ',name," ....")
+    checkpoint_diff = os.path.join(DEFAULT_MODELS_DIR,name,'s2a_latest.pt')
+    checkpoint_ts = os.path.join(DEFAULT_MODELS_DIR,name,'t2s_best.pt')
+    checkpoint_voco = os.path.join(DEFAULT_MODELS_DIR,'hifigan','g_02500000')
+    voco_config_path = os.path.join(DEFAULT_MODELS_DIR,'hifigan','config.json')
+
+    os.makedirs(os.path.join(DEFAULT_MODELS_DIR,name),exist_ok=True)
+
+    if name == 'hifigan':
+        download_file(model_dirs[name][0],checkpoint_voco)
+        download_file(model_dirs[name][1],voco_config_path)
+
+    else:
+        download_file(model_dirs[name][0],checkpoint_diff)
+        download_file(model_dirs[name][1],checkpoint_ts)
 
 def load_models(name,device=torch.device('cpu')):
     '''
@@ -51,10 +93,10 @@ def load_models(name,device=torch.device('cpu')):
 
     assert name in model_dirs, "no model name "+name
 
-    checkpoint_diff =
-    checkpoint_ts =
-    checkpoint_voco = '
-    voco_config_path = '
+    checkpoint_diff = os.path.join(DEFAULT_MODELS_DIR,name,'s2a_latest.pt')
+    checkpoint_ts = os.path.join(DEFAULT_MODELS_DIR,name,'t2s_best.pt')
+    checkpoint_voco = os.path.join(DEFAULT_MODELS_DIR,'hifigan','g_02500000')
+    voco_config_path = os.path.join(DEFAULT_MODELS_DIR,'hifigan','config.json')
 
     # for i in [checkpoint_diff,checkpoint_ts,checkpoint_voco,voco_config_path]:
     if not os.path.exists(checkpoint_diff) or not os.path.exists(checkpoint_ts):
@@ -64,15 +106,16 @@ def load_models(name,device=torch.device('cpu')):
         download_model('hifigan')
 
     diff_model = load_diff_model(checkpoint_diff,device)
-    ts_model = load_TS_model(checkpoint_ts,device)
+    ts_model = load_TS_model(checkpoint_ts,device,name)
    vocoder = load_vocoder_model(voco_config_path,checkpoint_voco,device)
    diffuser = load_diffuser()
 
    return diff_model,ts_model,vocoder,diffuser
 
-def infer_mel(model,timeshape,code,ref_mel,diffuser,temperature=0
+def infer_mel(model,timeshape,code,ref_mel,diffuser,temperature=1.0):
    device = next(model.parameters()).device
    code = code.to(device)
+    ref_mel =ref_mel.to(device)
    output_shape = (1,80,timeshape)
    noise = torch.randn(output_shape, device=code.device) * temperature
    mel = diffuser.p_sample_loop(model, output_shape, noise=noise,
@@ -84,17 +127,18 @@ def generate_semantic_tokens(
        text,
        model,
        ref_mels,
+        language=None,
        temp = 0.7,
        top_p= None,
-        top_k=
+        top_k= 1,
        n_tot_steps = 1000,
        device = None
        ):
    semb = []
    with torch.no_grad():
-        for n in range(n_tot_steps):
-            x = get_inputs(text,semb,ref_mels,device)
-            _,result = model(**x)
+        for n in tqdm(range(n_tot_steps)):
+            x = get_inputs(text,semb,ref_mels,device,model.name)
+            _,result = model(**x,language=language)
            relevant_logits = result[0,:,-1]
            if top_p is not None:
                # faster to convert to numpy
@@ -125,9 +169,13 @@ def generate_semantic_tokens(
    semb = torch.tensor([int(i) for i in semb[:-1]])
    return semb,result
 
-def get_inputs(text,semb=[],ref_mels=[],device=torch.device('cpu')):
+def get_inputs(text,semb=[],ref_mels=[],device=torch.device('cpu'),name = 'Smolie-in'):
    text = text.lower()
-
+    if name=='Smolie-en':
+        text_ids=[text_enc_en['<S>']]+[text_enc_en[i] for i in text.strip()]+[text_enc_en['<E>']]
+    else:
+        text_ids=[text_enc['<S>']]+[text_enc[i] for i in text.strip()]+[text_enc['<E>']]
+
    semb_ids=[code_enc['<SST>']]+[code_enc[i] for i in semb]#+[tok_enc['<EST>']]
 
    input_ids = text_ids+semb_ids
@@ -166,7 +214,7 @@ def get_mel(filepath):
    energy = torch.norm(magnitudes, dim=1).squeeze(0)
    return melspec,list(energy)
 
-def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder):
+def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder,language=None):
    '''
    Generate audio from the given text using a text-to-speech (TTS) pipeline.
 
@@ -193,6 +241,7 @@ def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder):
    Example usage:
    audio, sampling_rate = infer_tts("Hello, how are you?", ref_clips, diffuser, diff_model, ts_model, vocoder)
    '''
+    device = next(ts_model.parameters()).device
    text = english_cleaners(text)
    ref_mels = get_ref_mels(ref_clips)
    with torch.no_grad():
@@ -200,20 +249,21 @@ def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder):
            text,
            ts_model,
            ref_mels,
+            language,
            temp = 0.7,
            top_p= 0.8,
            top_k= 5,
            n_tot_steps = 1000,
-            device =
+            device = device
            )
    mel = infer_mel(diff_model,int(((sem_tok.shape[-1] * 320 / 16000) * 22050/256)+1),sem_tok.unsqueeze(0) + 1,
-                    ref_mels,diffuser,temperature=
+                    normalize_tacotron_mel(ref_mels),diffuser,temperature=0.5)
 
    audio = infer_wav(mel,vocoder)
 
    return audio,config.sampling_rate
 
-def load_diffuser(timesteps = 100,
+def load_diffuser(timesteps = 100, guidance=3):
    '''
    Load and configure a diffuser for denoising and guidance in the diffusion model.
 
@@ -227,10 +277,10 @@ def load_diffuser(timesteps = 100, gudiance=3):
    Description:
    The `load_diffuser` function initializes a diffuser with specific settings for denoising and guidance.
    '''
-    betas = get_named_beta_schedule('
+    betas = get_named_beta_schedule('linear',config.sa_timesteps_max)
    diffuser = SpacedDiffusion(use_timesteps=space_timesteps(1000, [timesteps]), model_mean_type='epsilon',
                               model_var_type='learned_range', loss_type='rescaled_mse', betas=betas,
-                               conditioning_free=True, conditioning_free_k=
+                               conditioning_free=True, conditioning_free_k=guidance)
    diffuser.training=False
    return diffuser
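The length argument passed to `infer_mel` above converts the number of generated semantic tokens into a mel-frame count. Restated as a small helper for readability (the 320-sample / 16 kHz and 256-sample / 22.05 kHz hop interpretation is inferred from the constants, not stated in the diff):

```python
# Restates int(((sem_tok.shape[-1] * 320 / 16000) * 22050 / 256) + 1) from infer_tts.
# Assumed reading: each semantic token covers 320 samples at 16 kHz (50 tokens/s),
# each mel frame covers 256 samples at 22.05 kHz (~86.1 frames/s).
def mel_frames_for(num_semantic_tokens: int) -> int:
    seconds = num_semantic_tokens * 320 / 16000
    return int(seconds * 22050 / 256) + 1

print(mel_frames_for(1000))  # 20 seconds of speech -> 1723 mel frames
```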
maha_tts/models/__init__.py DELETED
File without changes
maha_tts/models/autoregressive.py DELETED

@@ -1,135 +0,0 @@
-'''
-Inspiration taken from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/models/autoregressive.py
-'''
-import os,sys
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import functools
-
-from typing import Any
-from torch.utils.data import Dataset,DataLoader
-from transformers import GPT2Tokenizer,GPT2Config, GPT2Model, GPT2LMHeadModel
-from tqdm import tqdm
-from maha_tts.config import config
-from maha_tts.text.symbols import labels,code_labels,text_labels
-from maha_tts.models.modules import GST
-
-def null_position_embeddings(range, dim):
-    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
-
-class TS_model(nn.Module):
-    def __init__(self,n_embed = 512, n_layer = 16, n_head = 8):
-        super(TS_model,self).__init__()
-
-        self.vocab_size=len(labels)
-        self.n_positions=config.t2s_position
-        self.n_embed=n_embed
-        self.n_layer=n_layer
-        self.n_head=n_head
-
-        self.config = GPT2Config(vocab_size=self.vocab_size,n_positions=self.n_positions,n_embd=self.n_embed,n_layer=self.n_layer,n_head=self.n_head)
-        self.gpt = GPT2Model(self.config)
-        del self.gpt.wpe
-        self.gpt.wpe = functools.partial(null_position_embeddings, dim=self.n_embed)
-        # Built-in token embeddings are unused.
-        del self.gpt.wte
-        self.GST = GST(model_channels=self.n_embed,num_heads=self.n_head,in_channels=config.n_mel_channels,k=1)
-        self.text_head = nn.Linear(self.n_embed,len(text_labels))
-        self.code_head = nn.Linear(self.n_embed,len(code_labels))
-
-        self.text_positional_embed = LearnedPositionEmbeddings(self.n_positions,self.n_embed)
-        self.code_positional_embed = LearnedPositionEmbeddings(self.n_positions,self.n_embed)
-
-        self.text_embed = nn.Embedding(len(text_labels),self.n_embed)
-        self.code_embed = nn.Embedding(len(code_labels),self.n_embed)
-        self.final_norm = nn.LayerNorm(self.n_embed)
-
-    def get_speaker_latent(self, ref_mels):
-        ref_mels = ref_mels.unsqueeze(1) if len(
-            ref_mels.shape) == 3 else ref_mels
-
-        conds = []
-        for j in range(ref_mels.shape[1]):
-            conds.append(self.GST(ref_mels[:, j,:,:]))
-
-        conds = torch.cat(conds, dim=-1)
-        conds = conds.mean(dim=-1)
-
-        return conds.unsqueeze(1)
-
-    def forward(self,text_ids,codes_ids = None,speaker_embed=None,ref_clips=None,return_loss = False):
-        assert speaker_embed is not None or ref_clips is not None
-        text_embed = self.text_embed(text_ids)
-        text_embed += self.text_positional_embed(text_embed)
-
-        code_embed = None
-        code_probs= None
-
-        if codes_ids is not None:
-            code_embed = self.code_embed(codes_ids)
-            code_embed+= self.code_positional_embed(code_embed)
-
-        if ref_clips is not None:
-            speaker_embed = self.get_speaker_latent(ref_clips)
-
-        text_embed,code_embed = self.get_logits(speaker_embed=speaker_embed,text_embed=text_embed,code_embed=code_embed)
-
-        text_probs = self.text_head(text_embed).permute(0,2,1)
-
-        if codes_ids is not None:
-            code_probs = self.code_head(code_embed).permute(0,2,1)
-
-        if return_loss:
-            loss_text = F.cross_entropy(text_probs[:,:,:-1], text_ids[:,1:].long(), reduce=False)
-            loss_mel = F.cross_entropy(code_probs[:,:,:-1], codes_ids[:,1:].long(), reduce=False)
-            return loss_text,loss_mel,code_probs
-
-        return text_probs,code_probs
-
-
-    def get_logits(self,speaker_embed,text_embed,code_embed=None):
-
-        if code_embed is not None:
-            embed = torch.cat([speaker_embed,text_embed,code_embed],dim=1)
-        else:
-            embed = torch.cat([speaker_embed,text_embed],dim=1)
-
-        gpt_output = self.gpt(inputs_embeds=embed, return_dict=True)
-        enc = gpt_output.last_hidden_state[:, 1:]
-        enc = self.final_norm(enc)
-        if code_embed is not None:
-            return enc[:,:text_embed.shape[1]],enc[:,-code_embed.shape[1]:]
-
-        return enc[:,:text_embed.shape[1]],None
-
-class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=.02):
-        super().__init__()
-        self.emb = nn.Embedding(seq_len, model_dim)
-        # Initializing this way is standard for GPT-2
-        self.emb.weight.data.normal_(mean=0.0, std=init)
-
-    def forward(self, x):
-        sl = x.shape[1]
-        return self.emb(torch.arange(0, sl, device=x.device))
-
-    def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-
-def load_TS_model(checkpoint,device):
-    sem_model= TS_model(n_embed = 512, n_layer = 16, n_head = 8)
-    sem_model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu')),strict=False)
-    sem_model.eval().to(device)
-
-    return sem_model
-
-if __name__ == '__main__':
-    model=TS_model(n_embed = 256, n_layer = 6, n_head = 4)
-
-    text_ids = torch.randint(0,100,(5,20))
-    code_ids = torch.randint(0,100,(5,200))
-    speaker_embed = torch.randn((5,1,256))
-
-    output=model(text_ids=text_ids,speaker_embed=speaker_embed,codes_ids=code_ids,return_loss=True)
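For readers skimming the deleted TS_model: GPT-2's built-in position table (`wpe`) is swapped for a function that returns zeros, so positional information comes only from the separate `LearnedPositionEmbeddings` applied to the text and semantic-code streams before they are concatenated. A tiny standalone illustration of that replacement (not part of the repo):

```python
import torch

# Same trick as in the deleted file: a drop-in "position embedding" that contributes nothing.
def null_position_embeddings(token_range, dim):
    return torch.zeros((token_range.shape[0], token_range.shape[1], dim), device=token_range.device)

tokens = torch.zeros(2, 10, dtype=torch.long)        # (batch, sequence)
print(null_position_embeddings(tokens, 512).shape)   # torch.Size([2, 10, 512])
```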
maha_tts/models/diff_model.py
DELETED
@@ -1,303 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
inspiration taken from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/models/diffusion_decoder.py
|
3 |
-
'''
|
4 |
-
import sys
|
5 |
-
import torch
|
6 |
-
import torch.nn as nn
|
7 |
-
import torch.nn.functional as F
|
8 |
-
import math
|
9 |
-
|
10 |
-
from maha_tts.config import config
|
11 |
-
from torch import autocast
|
12 |
-
from maha_tts.models.modules import QuartzNetBlock,AttentionBlock,mySequential,normalization,SCBD,SqueezeExcite,GST
|
13 |
-
|
14 |
-
def timestep_embedding(timesteps, dim, max_period=10000):
|
15 |
-
"""
|
16 |
-
Create sinusoidal timestep embeddings.
|
17 |
-
|
18 |
-
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
19 |
-
These may be fractional.
|
20 |
-
:param dim: the dimension of the output.
|
21 |
-
:param max_period: controls the minimum frequency of the embeddings.
|
22 |
-
:return: an [N x dim] Tensor of positional embeddings.
|
23 |
-
"""
|
24 |
-
half = dim // 2
|
25 |
-
freqs = torch.exp(
|
26 |
-
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
27 |
-
).to(device=timesteps.device)
|
28 |
-
args = timesteps[:, None].float() * freqs[None]
|
29 |
-
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
30 |
-
if dim % 2:
|
31 |
-
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
32 |
-
return embedding
|
33 |
-
|
34 |
-
class TimestepBlock(nn.Module):
|
35 |
-
def forward(self, x, emb):
|
36 |
-
"""
|
37 |
-
Apply the module to `x` given `emb` timestep embeddings.
|
38 |
-
"""
|
39 |
-
|
40 |
-
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
41 |
-
def forward(self, x, emb):
|
42 |
-
for layer in self:
|
43 |
-
if isinstance(layer, TimestepBlock):
|
44 |
-
x = layer(x, emb)
|
45 |
-
else:
|
46 |
-
x = layer(x)
|
47 |
-
return x
|
48 |
-
|
49 |
-
class QuartzNetBlock(TimestepBlock):
|
50 |
-
'''Similar to Resnet block with Batchnorm and dropout, and using Separable conv in the middle.
|
51 |
-
if its the last layer,set se = False and separable = False, and use a projection layer on top of this.
|
52 |
-
'''
|
53 |
-
def __init__(self,nin,nout,emb_channels,kernel_size=3,dropout=0.1,R=1,se=True,ratio=8,separable=False,bias=True,use_scale_shift_norm=True):
|
54 |
-
super(QuartzNetBlock,self).__init__()
|
55 |
-
self.use_scale_shift_norm = use_scale_shift_norm
|
56 |
-
self.se=se
|
57 |
-
self.in_layers = mySequential(
|
58 |
-
nn.Conv1d(nin,nout,kernel_size=1,padding='same',bias=bias),
|
59 |
-
normalization(nout) #nn.BatchNorm1d(nout,eps)
|
60 |
-
)
|
61 |
-
|
62 |
-
self.residual=mySequential(
|
63 |
-
nn.Conv1d(nin,nout,kernel_size=1,padding='same',bias=bias),
|
64 |
-
normalization(nout) #nn.BatchNorm1d(nout,eps)
|
65 |
-
)
|
66 |
-
|
67 |
-
nin=nout
|
68 |
-
model=[]
|
69 |
-
|
70 |
-
self.emb_layers = nn.Sequential(
|
71 |
-
nn.SiLU(),
|
72 |
-
nn.Linear(
|
73 |
-
emb_channels,
|
74 |
-
2 * nout if use_scale_shift_norm else nout,
|
75 |
-
),
|
76 |
-
)
|
77 |
-
|
78 |
-
for i in range(R-1):
|
79 |
-
model.append(SCBD(nin,nout,kernel_size,dropout,bias=bias))
|
80 |
-
nin=nout
|
81 |
-
|
82 |
-
if separable:
|
83 |
-
model.append(SCBD(nin,nout,kernel_size,dropout,rd=False,bias=bias))
|
84 |
-
else:
|
85 |
-
model.append(SCBD(nin,nout,kernel_size,dropout,rd=False,separable=False,bias=bias))
|
86 |
-
|
87 |
-
self.model=mySequential(*model)
|
88 |
-
if self.se:
|
89 |
-
self.se_layer=SqueezeExcite(nin,ratio)
|
90 |
-
|
91 |
-
self.mout= mySequential(nn.SiLU(),nn.Dropout(dropout))
|
92 |
-
|
93 |
-
def forward(self,x,emb,mask=None):
|
94 |
-
x_new=self.in_layers(x)
|
95 |
-
emb = self.emb_layers(emb)
|
96 |
-
while len(emb.shape) < len(x_new.shape):
|
97 |
-
emb = emb[..., None]
|
98 |
-
scale, shift = torch.chunk(emb, 2, dim=1)
|
99 |
-
x_new = x_new * (1 + scale) + shift
|
100 |
-
y,_=self.model(x_new)
|
101 |
-
|
102 |
-
if self.se:
|
103 |
-
y,_=self.se_layer(y,mask)
|
104 |
-
y+=self.residual(x)
|
105 |
-
y=self.mout(y)
|
106 |
-
|
107 |
-
return y
|
108 |
-
|
109 |
-
class QuartzAttn(TimestepBlock):
|
110 |
-
def __init__(self, model_channels, dropout, num_heads):
|
111 |
-
super().__init__()
|
112 |
-
self.resblk = QuartzNetBlock(model_channels, model_channels, model_channels,dropout=dropout,use_scale_shift_norm=True)
|
113 |
-
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
114 |
-
|
115 |
-
def forward(self, x, time_emb):
|
116 |
-
y = self.resblk(x, time_emb)
|
117 |
-
return self.attn(y)
|
118 |
-
|
119 |
-
class QuartzNet9x5(nn.Module):
|
120 |
-
def __init__(self,model_channels,num_heads,enable_fp16=False):
|
121 |
-
super(QuartzNet9x5,self).__init__()
|
122 |
-
self.enable_fp16 = enable_fp16
|
123 |
-
|
124 |
-
self.conv1=QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=3,dropout=0.1,R=3)
|
125 |
-
kernels=[5,7,9,13,15,17]
|
126 |
-
quartznet=[]
|
127 |
-
attn=[]
|
128 |
-
for i in kernels:
|
129 |
-
quartznet.append(QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=i,dropout=0.1,R=5,se=True))
|
130 |
-
attn.append(AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True))
|
131 |
-
kernels=[21,23,25]
|
132 |
-
quartznet.append(QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=21,dropout=0.1,R=5,se=True))
|
133 |
-
attn.append(AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True))
|
134 |
-
|
135 |
-
for i in kernels[1:]:
|
136 |
-
quartznet.append(QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=i,dropout=0.1,R=5,se=True))
|
137 |
-
attn.append(AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True))
|
138 |
-
self.quartznet= nn.ModuleList(quartznet)
|
139 |
-
self.attn = nn.ModuleList(attn)
|
140 |
-
self.conv3=nn.Conv1d(model_channels, model_channels, 1, padding='same')
|
141 |
-
|
142 |
-
|
143 |
-
def forward(self, x, time_emb):
|
144 |
-
x = self.conv1(x,time_emb)
|
145 |
-
# with autocast(x.device.type, enabled=self.enable_fp16):
|
146 |
-
for n,(layer,attn) in enumerate(zip(self.quartznet,self.attn)):
|
147 |
-
x = layer(x,time_emb) #256 dim
|
148 |
-
x = attn(x)
|
149 |
-
x = self.conv3(x.float())
|
150 |
-
return x
|
151 |
-
|
152 |
-
class DiffModel(nn.Module):
|
153 |
-
|
154 |
-
def __init__(
|
155 |
-
self,
|
156 |
-
input_channels=80,
|
157 |
-
output_channels=160,
|
158 |
-
model_channels=512,
|
159 |
-
num_heads=8,
|
160 |
-
dropout=0.0,
|
161 |
-
multispeaker = True,
|
162 |
-
condition_free_per=0.1,
|
163 |
-
training = False,
|
164 |
-
ar_active = False,
|
165 |
-
in_latent_channels = 10004
|
166 |
-
):
|
167 |
-
|
168 |
-
super().__init__()
|
169 |
-
self.input_channels = input_channels
|
170 |
-
self.model_channels = model_channels
|
171 |
-
self.output_channels = output_channels
|
172 |
-
self.num_heads = num_heads
|
173 |
-
self.dropout = dropout
|
174 |
-
self.condition_free_per = condition_free_per
|
175 |
-
self.training = training
|
176 |
-
self.multispeaker = multispeaker
|
177 |
-
self.ar_active = ar_active
|
178 |
-
self.in_latent_channels = in_latent_channels
|
179 |
-
|
180 |
-
if not self.ar_active:
|
181 |
-
self.code_emb = nn.Embedding(config.semantic_model_centroids+1,model_channels)
|
182 |
-
self.code_converter = mySequential(
|
183 |
-
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
184 |
-
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
185 |
-
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
186 |
-
)
|
187 |
-
else:
|
188 |
-
self.code_converter = mySequential(
|
189 |
-
nn.Conv1d(self.in_latent_channels, model_channels, 3, padding=1),
|
190 |
-
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
191 |
-
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
192 |
-
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
193 |
-
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
194 |
-
)
|
195 |
-
if self.multispeaker:
|
196 |
-
self.GST = GST(model_channels,num_heads)
|
197 |
-
|
198 |
-
self.code_norm = normalization(model_channels)
|
199 |
-
self.time_norm = normalization(model_channels)
|
200 |
-
self.noise_norm = normalization(model_channels)
|
201 |
-
self.code_time_norm = normalization(model_channels)
|
202 |
-
|
203 |
-
# self.code_latent = []
|
204 |
-
self.time_embed = mySequential(
|
205 |
-
nn.Linear(model_channels, model_channels),
|
206 |
-
nn.SiLU(),
|
207 |
-
nn.Linear(model_channels, model_channels),)
|
208 |
-
|
209 |
-
self.input_block = nn.Conv1d(input_channels,model_channels,3,1,1)
|
210 |
-
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
|
211 |
-
|
212 |
-
self.code_time = TimestepEmbedSequential(QuartzAttn(model_channels, dropout, num_heads),QuartzAttn(model_channels, dropout, num_heads),QuartzAttn(model_channels, dropout, num_heads))
|
213 |
-
self.layers = QuartzNet9x5(model_channels,num_heads)
|
214 |
-
|
215 |
-
self.out = nn.Sequential(
|
216 |
-
normalization(model_channels),
|
217 |
-
nn.SiLU(),
|
218 |
-
nn.Conv1d(model_channels, output_channels, 3, padding=1),
|
219 |
-
)
|
220 |
-
|
221 |
-
def get_speaker_latent(self, ref_mels):
|
222 |
-
ref_mels = ref_mels.unsqueeze(1) if len(
|
223 |
-
ref_mels.shape) == 3 else ref_mels
|
224 |
-
|
225 |
-
conds = []
|
226 |
-
for j in range(ref_mels.shape[1]):
|
227 |
-
conds.append(self.GST(ref_mels[:, j,:,:]))
|
228 |
-
|
229 |
-
conds = torch.cat(conds, dim=-1)
|
230 |
-
conds = conds.mean(dim=-1)
|
231 |
-
|
232 |
-
return conds.unsqueeze(2)
|
233 |
-
|
234 |
-
def forward(self ,x,t,code_emb,ref_clips=None,speaker_latents=None,conditioning_free=False):
|
235 |
-
time_embed = self.time_norm(self.time_embed(timestep_embedding(t.unsqueeze(-1),self.model_channels)).permute(0,2,1)).squeeze(2)
|
236 |
-
if conditioning_free:
|
237 |
-
code_embed = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
238 |
-
else:
|
239 |
-
if not self.ar_active:
|
240 |
-
code_embed = self.code_norm(self.code_converter(self.code_emb(code_emb).permute(0,2,1)))
|
241 |
-
else:
|
242 |
-
code_embed = self.code_norm(self.code_converter(code_emb))
|
243 |
-
if self.multispeaker:
|
244 |
-
assert speaker_latents is not None or ref_clips is not None
|
245 |
-
if ref_clips is not None:
|
246 |
-
speaker_latents = self.get_speaker_latent(ref_clips)
|
247 |
-
cond_scale, cond_shift = torch.chunk(speaker_latents, 2, dim=1)
|
248 |
-
code_embed = code_embed * (1 + cond_scale) + cond_shift
|
249 |
-
if self.training and self.condition_free_per > 0:
|
250 |
-
unconditioned_batches = torch.rand((code_embed.shape[0], 1, 1),
|
251 |
-
device=code_embed.device) < self.condition_free_per
|
252 |
-
code_embed = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_embed.shape[0], 1, 1),
|
253 |
-
code_embed)
|
254 |
-
|
255 |
-
expanded_code_emb = F.interpolate(code_embed, size=x.shape[-1], mode='nearest') #try different modes
|
256 |
-
|
257 |
-
x_cond = self.code_time_norm(self.code_time(expanded_code_emb,time_embed))
|
258 |
-
|
259 |
-
x = self.noise_norm(self.input_block(x))
|
260 |
-
x += x_cond
|
261 |
-
x = self.layers(x, time_embed)
|
262 |
-
out = self.out(x)
|
263 |
-
return out
|
264 |
-
|
265 |
-
def load_diff_model(checkpoint,device,model_channels=512,ar_active=False,len_code_labels=10004):
|
266 |
-
diff_model = DiffModel(input_channels=80,
|
267 |
-
output_channels=160,
|
268 |
-
model_channels=512,
|
269 |
-
num_heads=8,
|
270 |
-
dropout=0.15,
|
271 |
-
condition_free_per=0.15,
|
272 |
-
multispeaker=True,
|
273 |
-
training=False,
|
274 |
-
ar_active=ar_active,
|
275 |
-
in_latent_channels = len_code_labels)
|
276 |
-
|
277 |
-
# diff_model.load_state_dict(torch.load('/content/LibriTTS_fp64_10k/S2A/_latest.pt',map_location=torch.device('cpu')),strict=True)
|
278 |
-
diff_model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu')),strict=True)
|
279 |
-
diff_model=diff_model.eval().to(device)
|
280 |
-
return diff_model
|
281 |
-
|
282 |
-
|
283 |
-
if __name__ == '__main__':
|
284 |
-
|
285 |
-
device = torch.device('cpu')
|
286 |
-
diff_model = DiffModel(input_channels=80,
|
287 |
-
output_channels=160,
|
288 |
-
model_channels=1024,
|
289 |
-
num_heads=8,
|
290 |
-
dropout=0.1,
|
291 |
-
num_layers=8,
|
292 |
-
enable_fp16=True,
|
293 |
-
condition_free_per=0.1,
|
294 |
-
multispeaker=True,
|
295 |
-
training=True).to(device)
|
296 |
-
|
297 |
-
batch_Size = 32
|
298 |
-
timeseries = 800
|
299 |
-
from torchinfo import summary
|
300 |
-
summary(diff_model, input_data={'x': torch.randn(batch_Size, 80, timeseries).to(device),
|
301 |
-
'ref_clips': torch.randn(batch_Size,3, 80, timeseries).to(device),
|
302 |
-
't':torch.LongTensor(size=[batch_Size,]).to(device),
|
303 |
-
'code_emb':torch.randint(0,201,(batch_Size,timeseries)).to(device)})
|
|
|
maha_tts/models/modules.py
DELETED
@@ -1,406 +0,0 @@
|
|
1 |
-
import torch,math
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
import torch.nn.init as init
|
5 |
-
from einops import rearrange, repeat
|
6 |
-
|
7 |
-
def zero_module(module):
|
8 |
-
"""
|
9 |
-
Zero out the parameters of a module and return it.
|
10 |
-
Using it for Zero Convolutions
|
11 |
-
"""
|
12 |
-
for p in module.parameters():
|
13 |
-
p.detach().zero_()
|
14 |
-
return module
|
15 |
-
|
16 |
-
|
17 |
-
class GroupNorm32(nn.GroupNorm):
|
18 |
-
def forward(self, x):
|
19 |
-
return super().forward(x.float()).type(x.dtype)
|
20 |
-
|
21 |
-
|
22 |
-
def normalization(channels):
|
23 |
-
"""
|
24 |
-
Make a standard normalization layer. of groups ranging from 2 to 32.
|
25 |
-
|
26 |
-
:param channels: number of input channels.
|
27 |
-
:return: an nn.Module for normalization.
|
28 |
-
"""
|
29 |
-
groups = 32
|
30 |
-
if channels <= 16:
|
31 |
-
groups = 8
|
32 |
-
elif channels <= 64:
|
33 |
-
groups = 16
|
34 |
-
while channels % groups != 0:
|
35 |
-
groups = int(groups / 2)
|
36 |
-
assert groups > 2
|
37 |
-
return GroupNorm32(groups, channels)
|
38 |
-
|
39 |
-
|
40 |
-
class mySequential(nn.Sequential):
|
41 |
-
'''Using this to pass mask variable to nn layers
|
42 |
-
'''
|
43 |
-
def forward(self, *inputs):
|
44 |
-
for module in self._modules.values():
|
45 |
-
if type(inputs) == tuple:
|
46 |
-
inputs = module(*inputs)
|
47 |
-
else:
|
48 |
-
inputs = module(inputs)
|
49 |
-
return inputs
|
50 |
-
|
51 |
-
class SepConv1D(nn.Module):
|
52 |
-
'''Depth wise separable Convolution layer with mask
|
53 |
-
'''
|
54 |
-
def __init__(self,nin,nout,kernel_size,stride=1,dilation=1,padding_mode='same',bias=True):
|
55 |
-
super(SepConv1D,self).__init__()
|
56 |
-
self.conv1=nn.Conv1d(nin, nin, kernel_size=kernel_size, stride=stride,groups=nin,dilation=dilation,padding=padding_mode,bias=bias)
|
57 |
-
self.conv2=nn.Conv1d(nin,nout,kernel_size=1,stride=1,padding=padding_mode,bias=bias)
|
58 |
-
|
59 |
-
def forward(self,x,mask=None):
|
60 |
-
if mask is not None:
|
61 |
-
x = x * mask.unsqueeze(1).to(device=x.device)
|
62 |
-
x=self.conv1(x)
|
63 |
-
x=self.conv2(x)
|
64 |
-
return x,mask
|
65 |
-
|
66 |
-
class Conv1DBN(nn.Module):
|
67 |
-
def __init__(self,nin,nout,kernel_size,stride=1,dilation=1,dropout=0.1,padding_mode='same',bias=False):
|
68 |
-
super(Conv1DBN,self).__init__()
|
69 |
-
self.conv1=nn.Conv1d(nin, nout, kernel_size=kernel_size, stride=stride,padding=padding_mode,dilation=dilation,bias=bias)
|
70 |
-
self.bn=nn.BatchNorm1d(nout)
|
71 |
-
self.drop=nn.Dropout(dropout)
|
72 |
-
|
73 |
-
def forward(self,x,mask=None):
|
74 |
-
if mask is not None:
|
75 |
-
x = x * mask.unsqueeze(1).to(device=x.device)
|
76 |
-
x=self.conv1(x)
|
77 |
-
x=self.bn(x)
|
78 |
-
x=F.relu(x)
|
79 |
-
x=self.drop(x)
|
80 |
-
return x,mask
|
81 |
-
|
82 |
-
class Conv1d(nn.Module):
|
83 |
-
'''normal conv1d with mask
|
84 |
-
'''
|
85 |
-
def __init__(self,nin,nout,kernel_size,padding,bias=True):
|
86 |
-
super(Conv1d,self).__init__()
|
87 |
-
self.l=nn.Conv1d(nin,nout,kernel_size,padding=padding,bias=bias)
|
88 |
-
def forward(self,x,mask):
|
89 |
-
if mask is not None:
|
90 |
-
x = x * mask.unsqueeze(1).to(device=x.device)
|
91 |
-
y=self.l(x)
|
92 |
-
return y,mask
|
93 |
-
|
94 |
-
class SqueezeExcite(nn.Module):
|
95 |
-
'''Let the CNN decide how to add across channels
|
96 |
-
'''
|
97 |
-
def __init__(self,nin,ratio=8):
|
98 |
-
super(SqueezeExcite,self).__init__()
|
99 |
-
self.nin=nin
|
100 |
-
self.ratio=ratio
|
101 |
-
|
102 |
-
self.fc=mySequential(
|
103 |
-
nn.Linear(nin,nin//ratio,bias=True),nn.SiLU(inplace=True),nn.Linear(nin//ratio,nin,bias=True)
|
104 |
-
)
|
105 |
-
|
106 |
-
def forward(self,x,mask=None):
|
107 |
-
if mask is None:
|
108 |
-
mask = torch.ones((x.shape[0],x.shape[-1]),dtype=torch.bool).to(x.device)
|
109 |
-
mask=~mask
|
110 |
-
x=x.float()
|
111 |
-
x.masked_fill_(mask.unsqueeze(1), 0.0)
|
112 |
-
mask=~mask
|
113 |
-
y = (torch.sum(x, dim=-1, keepdim=True) / mask.unsqueeze(1).sum(dim=-1, keepdim=True)).type(x.dtype)
|
114 |
-
# y=torch.mean(x,-1,keepdim=True)
|
115 |
-
y=y.transpose(1, -1)
|
116 |
-
y=self.fc(y)
|
117 |
-
y=torch.sigmoid(y)
|
118 |
-
y=y.transpose(1, -1)
|
119 |
-
y= x * y
|
120 |
-
return y,mask
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
class SCBD(nn.Module):
|
125 |
-
'''SeparableConv1D + Batchnorm + Dropout, Generally use it for middle layers and resnet
|
126 |
-
'''
|
127 |
-
def __init__(self,nin,nout,kernel_size,p=0.1,rd=True,separable=True,bias=True):
|
128 |
-
super(SCBD,self).__init__()
|
129 |
-
if separable:
|
130 |
-
self.SC=SepConv1D(nin,nout,kernel_size,bias=bias)
|
131 |
-
else:
|
132 |
-
self.SC=Conv1d(nin,nout,kernel_size,padding='same',bias=bias)
|
133 |
-
|
134 |
-
if rd: #relu and Dropout
|
135 |
-
self.mout=mySequential(normalization(nout),nn.SiLU(), # nn.BatchNorm1d(nout,eps)
|
136 |
-
nn.Dropout(p))
|
137 |
-
else:
|
138 |
-
self.mout=normalization(nout) # nn.BatchNorm1d(nout,eps)
|
139 |
-
|
140 |
-
def forward(self,x,mask=None):
|
141 |
-
if mask is not None:
|
142 |
-
x = x * mask.unsqueeze(1).to(device=x.device)
|
143 |
-
x,_= self.SC(x,mask)
|
144 |
-
y = self.mout(x)
|
145 |
-
return y,mask
|
146 |
-
|
147 |
-
class QuartzNetBlock(nn.Module):
|
148 |
-
'''Similar to Resnet block with Batchnorm and dropout, and using Separable conv in the middle.
|
149 |
-
if its the last layer,set se = False and separable = False, and use a projection layer on top of this.
|
150 |
-
'''
|
151 |
-
def __init__(self,nin,nout,kernel_size,dropout=0.1,R=5,se=False,ratio=8,separable=False,bias=True):
|
152 |
-
super(QuartzNetBlock,self).__init__()
|
153 |
-
self.se=se
|
154 |
-
self.residual=mySequential(
|
155 |
-
nn.Conv1d(nin,nout,kernel_size=1,padding='same',bias=bias),
|
156 |
-
normalization(nout) #nn.BatchNorm1d(nout,eps)
|
157 |
-
)
|
158 |
-
model=[]
|
159 |
-
|
160 |
-
for i in range(R-1):
|
161 |
-
model.append(SCBD(nin,nout,kernel_size,dropout,eps=0.001,bias=bias))
|
162 |
-
nin=nout
|
163 |
-
|
164 |
-
if separable:
|
165 |
-
model.append(SCBD(nin,nout,kernel_size,dropout,eps=0.001,rd=False,bias=bias))
|
166 |
-
else:
|
167 |
-
model.append(SCBD(nin,nout,kernel_size,dropout,eps=0.001,rd=False,separable=False,bias=bias))
|
168 |
-
self.model=mySequential(*model)
|
169 |
-
|
170 |
-
if self.se:
|
171 |
-
self.se_layer=SqueezeExcite(nin,ratio)
|
172 |
-
|
173 |
-
self.mout= mySequential(nn.SiLU(),nn.Dropout(dropout))
|
174 |
-
|
175 |
-
def forward(self,x,mask=None):
|
176 |
-
if mask is not None:
|
177 |
-
x = x * mask.unsqueeze(1).to(device=x.device)
|
178 |
-
y,_=self.model(x,mask)
|
179 |
-
if self.se:
|
180 |
-
y,_=self.se_layer(y,mask)
|
181 |
-
y+=self.residual(x)
|
182 |
-
y=self.mout(y)
|
183 |
-
return y,mask
|
184 |
-
|
185 |
-
class QKVAttentionLegacy(nn.Module):
|
186 |
-
"""
|
187 |
-
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
|
188 |
-
"""
|
189 |
-
|
190 |
-
def __init__(self, n_heads):
|
191 |
-
super().__init__()
|
192 |
-
self.n_heads = n_heads
|
193 |
-
|
194 |
-
def forward(self, qkv, mask=None, rel_pos=None):
|
195 |
-
"""
|
196 |
-
Apply QKV attention.
|
197 |
-
|
198 |
-
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
199 |
-
:return: an [N x (H * C) x T] tensor after attention.
|
200 |
-
"""
|
201 |
-
bs, width, length = qkv.shape
|
202 |
-
assert width % (3 * self.n_heads) == 0
|
203 |
-
ch = width // (3 * self.n_heads)
|
204 |
-
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
205 |
-
scale = 1 / math.sqrt(math.sqrt(ch))
|
206 |
-
weight = torch.einsum(
|
207 |
-
"bct,bcs->bts", q * scale, k * scale
|
208 |
-
) # More stable with f16 than dividing afterwards
|
209 |
-
if rel_pos is not None:
|
210 |
-
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
|
211 |
-
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
212 |
-
if mask is not None:
|
213 |
-
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
|
214 |
-
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
|
215 |
-
weight = weight * mask
|
216 |
-
a = torch.einsum("bts,bcs->bct", weight, v)
|
217 |
-
|
218 |
-
return a.reshape(bs, -1, length)
|
219 |
-
|
220 |
-
class AttentionBlock(nn.Module):
|
221 |
-
"""
|
222 |
-
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))  # no effect of attention in the inital stages.
        # if relative_pos_embeddings:
        self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) #need to read about this, vit and swin transformers
        # self.relative_pos_embeddings = FixedPositionalEmbedding(dim=channels)
        # else:
        #     self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        n = torch.arange(x.shape[1], device=x.device)
        pos_emb = self.emb(n)
        pos_emb = rearrange(pos_emb, 'n d -> () n d')
        return pos_emb * self.scale


class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return rearrange(emb, 'n d -> () n d')

class RelativePositionBias(nn.Module):
    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
                                                   max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> () h i j')
        return qk_dots + (bias * self.scale)


class MultiHeadAttention(nn.Module):
    '''
    only for GST
    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    '''
    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
        scores = F.softmax(scores, dim=3)

        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out


class GST(nn.Module):
    def __init__(self,model_channels=512,num_heads=8,in_channels=80,k=2):
        super(GST,self).__init__()
        self.model_channels=model_channels
        self.num_heads=num_heads

        self.reference_encoder=nn.Sequential(
            nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
            nn.Conv1d(model_channels, model_channels*k,3,padding=1,stride=2),
            AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False)
        )

    def forward(self,x):
        x=self.reference_encoder(x)
        return x


if __name__ == '__main__':
    device = torch.device('cpu')
    m = GST(512,10).to(device)
    mels = torch.rand((16,80,1000)).to(device)

    o = m(mels)
    print(o.shape,'final output')

    from torchinfo import summary
    summary(m, input_data={'x': torch.randn(16,80,500).to(device)})
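The `_relative_position_bucket` logic above is the least obvious part of this deleted module: nearby query–key offsets each keep their own bucket, while larger offsets share logarithmically spaced buckets, with the two directions split across the two halves of the bucket range. A small standalone sketch of that mapping (it mirrors the deleted code rather than importing it, and clamps `n` to avoid `log(0)`, which the original only sidesteps because `torch.where` discards those values):

```python
import math
import torch

def relative_position_bucket(relative_position, num_buckets=32, max_distance=64):
    # non-causal variant, matching how AttentionBlock configures RelativePositionBias
    n = -relative_position
    num_buckets //= 2
    ret = (n < 0).long() * num_buckets        # one half of the buckets per direction
    n = torch.abs(n)
    max_exact = num_buckets // 2              # small offsets keep their exact distance
    val_if_large = max_exact + (
        torch.log(n.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    ).long()
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
    return ret + torch.where(n < max_exact, n, val_if_large)

offsets = torch.tensor([-64, -32, -8, -1, 0, 1, 8, 32, 64])
print(relative_position_bucket(offsets))
# tensor([15, 13,  8,  1,  0, 17, 24, 29, 31])
```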
maha_tts/models/vocoder.py
DELETED
@@ -1,342 +0,0 @@
'''
copde from https://github.com/jik876/hifi-gan/blob/master/models.py
'''

import json,os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
# from utils import init_weights, get_padding

LRELU_SLOPE = 0.1

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        # print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0: # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i-1](y)
                y_hat = self.meanpools[i-1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses

def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    checkpoint_dict = torch.load(filepath, map_location=device)
    return checkpoint_dict

def load_vocoder_model(config_path,checkpoint_path,device):
    # config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
    with open(config_path) as f:
        data = f.read()

    global h
    json_config = json.loads(data)
    h = AttrDict(json_config)

    torch.manual_seed(h.seed)

    generator = Generator(h).to(device)

    state_dict_g = load_checkpoint(checkpoint_path, device)
    generator.load_state_dict(state_dict_g['generator'])

    generator.eval()
    generator.remove_weight_norm()

    return generator

def infer_wav(mel,generator):
    MAX_WAV_VALUE =32768.0
    with torch.no_grad():
        y_g_hat = generator(mel)
        audio = y_g_hat.squeeze()
        audio = audio * MAX_WAV_VALUE
        audio = audio.cpu().numpy().astype('int16')
    return audio
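For orientation, the two helpers at the bottom of this deleted file were the whole vocoder API: `load_vocoder_model` builds the HiFi-GAN `Generator` from a JSON config plus checkpoint, and `infer_wav` turns a mel spectrogram into 16-bit PCM. A hedged sketch of how they fit together as defined above (the paths point at the `hifigan` checkpoints tracked in this repo, but the exact layout and the output sample rate depend on the shipped config, not on this snippet):

```python
import torch
from scipy.io.wavfile import write

device = torch.device('cpu')

# load_vocoder_model / infer_wav as defined in the deleted vocoder.py above
generator = load_vocoder_model('maha_tts/pretrained_models/hifigan/config.json',
                               'maha_tts/pretrained_models/hifigan/g_02500000',
                               device)

mel = torch.randn(1, 80, 200, device=device)   # [batch, n_mels, frames] -- dummy input
audio = infer_wav(mel, generator)              # int16 numpy array
write('vocoder_smoke_test.wav', 22050, audio)  # 22050 Hz is an assumption; use the rate from the HiFi-GAN config
```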
maha_tts/pretrained_models/.DS_Store
CHANGED
Binary files a/maha_tts/pretrained_models/.DS_Store and b/maha_tts/pretrained_models/.DS_Store differ

maha_tts/pretrained_models/Smolie-en/.DS_Store
ADDED
Binary file (6.15 kB)

maha_tts/pretrained_models/Smolie-en/s2a_latest.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1cb3aef9bebda0535dce135de3ae5f23f62ec3890eed87469dfe4a9a07f0f98
size 1720934888

maha_tts/pretrained_models/{smolie/T2S → Smolie-en}/t2s_best.pt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:6be1b489366ebbd35e55404be875804d380b0430587319f67e592da7ba1b5240
+size 276143363

maha_tts/pretrained_models/Smolie-in/.DS_Store
ADDED
Binary file (6.15 kB)

maha_tts/pretrained_models/Smolie-in/s2a_latest.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce73f611d8071f69111b71363fdad9465d84c4431fc26dc6b6de4595591c3305
size 1720934441

maha_tts/pretrained_models/{smolie/S2A/s2a_latest.pt → Smolie-in/t2s_best.pt}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:9c867f8a11f364b4cf543b42335e0a7f0450078693c66539296de1adcf2f27e6
+size 823446386
maha_tts/text/cleaners.py
CHANGED
@@ -135,8 +135,8 @@ def transliteration_cleaners(text):
 
 def english_cleaners(text):
     '''Pipeline for English text, including number and abbreviation expansion.'''
-    text = convert_to_ascii(text)
-    text = lowercase(text)
+    # text = convert_to_ascii(text)
+    # text = lowercase(text)
     text = expand_numbers(text)
     text = expand_abbreviations(text)
     text = collapse_whitespace(text)
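The two commented-out calls are presumably disabled because the new English symbol set in symbols.py (`labels_en`, below) keeps accented characters such as à/é/ü and typographic quotes, which Unidecode-style ASCII folding would collapse, and non-Latin scripts fare even worse. A small illustration of that folding, using the same Unidecode package pinned in requirements.txt (exact transliterations are Unidecode's choice, so outputs are indicative only):

```python
from unidecode import unidecode

print(unidecode('café “quoted”'))   # accents and curly quotes are folded to plain ASCII
print(unidecode('नमस्ते दुनिया'))      # Devanagari is replaced by a Latin transliteration
```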
maha_tts/text/symbols.py
CHANGED
@@ -2,12 +2,18 @@ import sys
 from maha_tts.config import config
 
 labels=" abcdefghijklmnopqrstuvwxyz.,:;'()?!\""
-
+labels_en=" !\"'(),-.:;?[]abcdefghijklmnopqrstuvwxyzàâèéêü’“”"
labels='''ଊతూിਮ০य़లഢਪਟକఝૂएड`यঢअచଢ଼ਧ—ତলશರଖच,பવड़ષंಈಮਤਇଥkखഗబ= इਸಣਹછ™ୟ.ोೀৎುഊଳંർਘମഴఙसଗൃlଝਜఇഓਐভയಅಠభാടਔಒ೧পஜaૅૠএଲ৯eകँ৭àৱऊટഒਗহિేயీെஈଓഭೊাੌಙ१ଈःസठખm‘ొऍಿcശrట।ऱଋઘਛெਬಂङಹஞ਼ભ১"એੂചಸગಷ়ଁമಓtஒઉಪs్-pଛ›ढ+ಆ'বનধৰউીଅઝ੍ೂʼൂఔfતषഖঢ়৬ਖक़ਵషணझപળଔઞੇವௗઁത২xెഥख़iটਲધಔೇீથ*ഝॅঃஓूఒীనਜ਼எુுహौ९ൗౌফഔોhஔণంफ़ఋçଯઊൽଆ’ୁைഛ२&ঁണ़ైৌআஆোਠਭजொમळಘஷഏি/ચਾ“ਯ$ଐീवऩ८ઢఛఎেథഠ[औಳରथୃൈಝnজਥऑଷੱल೯wओଵढ़மവरడఊbೖਈૃपdêଉఐ;ै ఢ ઔકচ৩ਊൾഉਕ೦ಏj€:ਦಗાളੁशफുழൻಊगફఏఅ?णറഘಞ४ಡಫଠ್ড೨ൊঞमਂસૉॉઅരஙલঘନ്ఠॄvઋృষऎகೕଘઆఞലେূஊఉૈദఫఈदকज़!ధઠవଞறಟਖ਼ਫ਼ইਢഡঠஃஸୂटঅହఆளోईৃಜ॥(ઈଏੀഈक્গ ಚಢഹೃिஏಯyশேଡೋੈਣડఃഷഇਸ਼நখಋோনૐਏgहৗೈृவੰଜग़ੋ୍)ൌరమൺংञਓપయധஇോ५ઃಲళঊತॽന…ঙಭाಇउਅଶরઓି্ூমuపബ\ૌଟबਆुಕଫதছ३దਿದణஐௌ்ৈqఘலહಾ०ಛঐிওऋి৮ेਨଇүଧഞಶéਚ्৫ୋశఓદঈୀ৪ପüুങਗ਼ઑજথఖঝಐऽਰାആജीઇੜ]आବଡ଼ഫಥుಎણଃયछஅેஹംଢબoদഎగଭాേഅঋসഐಃzਡಬਝன–உಖಉഃযସୈೆకॐನഋয়సசଙড়ୱऒऐઐतଂாতરâèनಧ॑டঔभர”జ৷ਫଣଚଦधघೌୌਉ'''
+
 labels= [i for i in labels]
+labels_en= [i for i in labels_en]
 
 text_labels = [i for i in labels]
 text_labels+='<S>','<E>','<PAD>'
 
+text_labels_en = [i for i in labels_en]
+text_labels_en+='<S>','<E>','<PAD>'
+
 code_labels= [str(i) for i in range(config.semantic_model_centroids)]
 labels+=code_labels
 code_labels+='<SST>','<EST>','<PAD>'
@@ -21,6 +27,10 @@ tok_dec = {i:j for i,j in enumerate(labels)}
 
 text_enc = {j:i for i,j in enumerate(text_labels)}
 text_dec = {i:j for i,j in enumerate(text_labels)}
 
+
+text_enc_en = {j:i for i,j in enumerate(text_labels_en)}
+text_dec_en = {i:j for i,j in enumerate(text_labels_en)}
+
 #code encdec
 code_enc = {j:i for i,j in enumerate(code_labels)}
 code_dec = {i:j for i,j in enumerate(code_labels)}
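Taken together, the module now carries two parallel symbol tables: `labels`/`text_enc` for the combined Indic character set and `labels_en`/`text_enc_en` for English. A rough sketch of how these per-character tables are used, based on my reading of the diff rather than code from the repo (characters outside the table would simply raise a KeyError):

```python
from maha_tts.text.symbols import text_enc, text_enc_en, text_dec_en

sentence = "hello, world!"
ids = [text_enc_en[ch] for ch in sentence]        # per-character ids for Smolie-en
roundtrip = ''.join(text_dec_en[i] for i in ids)  # decode back to the original string
assert roundtrip == sentence

# the Indic table works the same way, just over the much larger character set
print(len(text_enc), len(text_enc_en))
```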
maha_tts/utils/audio.py
CHANGED
@@ -6,8 +6,8 @@ from scipy.signal import get_window
 from scipy.io.wavfile import read
 from maha_tts.config import config
 
-TACOTRON_MEL_MAX = 2.
-TACOTRON_MEL_MIN = -11.
+TACOTRON_MEL_MAX = 2.4
+TACOTRON_MEL_MIN = -11.5130
 
 
 def denormalize_tacotron_mel(norm_mel):
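The only change here is the clipping range for normalized mels (2.0/-11.0 → 2.4/-11.5130). For reference, this is the usual Tacotron-style min-max mapping these constants feed into; a minimal sketch assuming the standard formulation (the real `normalize_tacotron_mel`/`denormalize_tacotron_mel` bodies live in maha_tts/utils/audio.py):

```python
TACOTRON_MEL_MAX = 2.4
TACOTRON_MEL_MIN = -11.5130

def normalize_tacotron_mel(mel):
    # map log-mel values from [MEL_MIN, MEL_MAX] to [-1, 1]
    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1

def denormalize_tacotron_mel(norm_mel):
    # inverse of the above: [-1, 1] back to [MEL_MIN, MEL_MAX]
    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
```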
ref_clips/2971_4275_000003_000007.wav
DELETED
Binary file (392 kB)

ref_clips/2971_4275_000020_000001.wav
DELETED
Binary file (386 kB)

ref_clips/2971_4275_000023_000010.wav
DELETED
Binary file (435 kB)

ref_clips/2971_4275_000049_000000.wav
DELETED
Binary file (366 kB)

ref_clips/2971_4275_000049_000004.wav
DELETED
Binary file (321 kB)

ref_clips/2971_4275_000050_000000.wav
DELETED
Binary file (385 kB)
requirements.txt
ADDED
@@ -0,0 +1,46 @@
annotated-types==0.6.0
audioread==3.0.1
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
decorator==5.1.1
einops==0.7.0
filelock==3.13.1
fsspec==2023.10.0
huggingface-hub==0.19.4
idna==3.4
inflect==7.0.0
Jinja2==3.1.2
joblib==1.3.2
lazy_loader==0.3
librosa==0.10.1
llvmlite==0.41.1
MarkupSafe==2.1.3
mpmath==1.3.0
msgpack==1.0.7
networkx==3.2.1
numba==0.58.1
numpy==1.26.2
packaging==23.2
platformdirs==4.0.0
pooch==1.8.0
pycparser==2.21
pydantic==2.5.1
pydantic_core==2.14.3
PyYAML==6.0.1
regex==2023.10.3
requests==2.31.0
safetensors==0.4.0
scikit-learn==1.3.2
scipy==1.11.3
soundfile==0.12.1
soxr==0.3.7
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.15.0
torch==2.1.1
tqdm==4.66.1
transformers==4.35.2
typing_extensions==4.8.0
Unidecode==1.3.7
urllib3==2.1.0
setup.py
ADDED
@@ -0,0 +1,23 @@
import os
from setuptools import setup, find_packages

__version__ = '1.0.0'
cwd = os.path.dirname(os.path.abspath(__file__))
# requirements = open(os.path.join(cwd, "requirements.txt"), "r").readlines()

setup(
    name='maha_tts',
    version=__version__,

    url='https://github.com/dubverse-ai/MahaTTS/tree/main',
    author='Dubverse AI',
    author_email='[email protected]',
    install_requires = [
        'einops',
        'transformers',
        'unidecode',
        'inflect'
    ],
    packages=find_packages(),
    py_modules=['maha_tts'],
)
tts.py
CHANGED
@@ -1,14 +1,16 @@
 import torch,glob
-from maha_tts import load_diffuser,load_models,infer_tts
+from maha_tts import load_diffuser,load_models,infer_tts,config
 from scipy.io.wavfile import write
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 print('Using:',device)
 text = 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.'
-
+langauge = 'english'
+language = torch.tensor(config.lang_index[langauge]).to(device).unsqueeze(0)
+ref_clips = glob.glob('models/Smolie-en/ref_clips/part0_1_1/*.wav')
 # print(len(ref_clips))
 
 # diffuser = load_diffuser()
-diff_model,ts_model,vocoder,diffuser = load_models('Smolie',device)
-audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder)
+diff_model,ts_model,vocoder,diffuser = load_models('Smolie-in',device)
+audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder,language)
 write('test.wav',sr,audio)
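The updated script loads the Indic checkpoint; switching back to the English one should only require swapping the model name and language index. A hedged variant of the last few lines above ('Smolie-en' as the model name and the 'english' key in config.lang_index are my reading of this commit's directory layout and config, so treat this as a sketch rather than a documented API):

```python
# assumes the same imports, device, text and ref_clips as the script above
language = torch.tensor(config.lang_index['english']).to(device).unsqueeze(0)
diff_model, ts_model, vocoder, diffuser = load_models('Smolie-en', device)
audio, sr = infer_tts(text, ref_clips, diffuser, diff_model, ts_model, vocoder, language)
write('test_en.wav', sr, audio)
```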