import glob
import os

import numpy as np
import requests
import torch
import torch.nn.functional as F

from tqdm import tqdm
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import write
from scipy.special import softmax
from maha_tts.models.diff_model import load_diff_model
from maha_tts.models.autoregressive import load_TS_model
from maha_tts.models.vocoder import load_vocoder_model,infer_wav
from maha_tts.utils.audio import denormalize_tacotron_mel,normalize_tacotron_mel,load_wav_to_torch,dynamic_range_compression
from maha_tts.utils.stft import STFT
from maha_tts.utils.diffusion import SpacedDiffusion,get_named_beta_schedule,space_timesteps
from maha_tts.text.symbols import labels,text_labels,text_labels_en,code_labels,text_enc,text_dec,code_enc,code_dec,text_enc_en,text_dec_en
from maha_tts.text.cleaners import english_cleaners
from maha_tts.config import config
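
# Pipeline overview: text -> semantic tokens (T2S autoregressive model)
# -> mel spectrogram (S2A diffusion model) -> waveform (HiFi-GAN vocoder).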

DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'maha_tts', 'models')
stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)

mel_basis = librosa_mel_fn(
    sr=config.sampling_rate, n_fft=config.filter_length, n_mels=config.n_mel_channels,
    fmin=config.mel_fmin, fmax=config.mel_fmax)

mel_basis = torch.from_numpy(mel_basis).float()

model_dirs = {
    'Smolie':['https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt',
                'https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/smolie/T2S/t2s_best.pt'],
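    # NOTE: checkpoint URLs for 'Smolie-en' and 'Smolie-in' are left blank in this file.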
    'Smolie-en':[''],
    'Smolie-in':[''],
    'hifigan':['https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/hifigan/g_02500000',
                'https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/hifigan/config.json']
}

def download_file(url, filename):
    response = requests.get(url, stream=True)
    # Fail fast on HTTP errors (e.g. 404) before streaming the body
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))

    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in response.iter_content(chunk_size=1024):
            # Write data to the file
            file.write(data)
            # Update the progress bar
            bar.update(len(data))

    print(f"Download complete: {filename}\n")

def download_model(name):
    print('Downloading', name, '....')
    checkpoint_diff = os.path.join(DEFAULT_MODELS_DIR,name,'s2a_latest.pt')
    checkpoint_ts = os.path.join(DEFAULT_MODELS_DIR,name,'t2s_best.pt')
    checkpoint_voco = os.path.join(DEFAULT_MODELS_DIR,'hifigan','g_02500000')
    voco_config_path = os.path.join(DEFAULT_MODELS_DIR,'hifigan','config.json')

    os.makedirs(os.path.join(DEFAULT_MODELS_DIR,name),exist_ok=True)
        
    if name == 'hifigan':
        download_file(model_dirs[name][0],checkpoint_voco)
        download_file(model_dirs[name][1],voco_config_path)
    
    else:
        download_file(model_dirs[name][0],checkpoint_diff)
        download_file(model_dirs[name][1],checkpoint_ts)

def load_models(name,device=torch.device('cpu')):
    '''
    Load pre-trained models for different components of a text-to-speech system.

    Args:
    name (str): Name of the model family to load; must be a key of `model_dirs` (e.g. 'Smolie').
    device (torch.device): The target device for model loading (e.g., torch.device('cpu') or torch.device('cuda')).

    Returns:
    diff_model (object): Loaded diffusion model for semantic-to-acoustic tokens.
    ts_model (object): Loaded text-to-semantic model for converting text-to-semantic tokens.
    vocoder (object): Loaded vocoder model for generating waveform from acoustic tokens.
    diffuser (object): Configured diffuser object for use in the diffusion model.
    '''

    assert name in model_dirs, f"unknown model name: {name}"

    checkpoint_diff = os.path.join(DEFAULT_MODELS_DIR,name,'s2a_latest.pt')
    checkpoint_ts = os.path.join(DEFAULT_MODELS_DIR,name,'t2s_best.pt')
    checkpoint_voco = os.path.join(DEFAULT_MODELS_DIR,'hifigan','g_02500000')
    voco_config_path = os.path.join(DEFAULT_MODELS_DIR,'hifigan','config.json')
    
    if not os.path.exists(checkpoint_diff) or not os.path.exists(checkpoint_ts):
        download_model(name)
    
    if not os.path.exists(checkpoint_voco) or not os.path.exists(voco_config_path):
        download_model('hifigan')

    diff_model = load_diff_model(checkpoint_diff,device)
    ts_model = load_TS_model(checkpoint_ts,device,name)
    vocoder = load_vocoder_model(voco_config_path,checkpoint_voco,device)
    diffuser = load_diffuser()

    return diff_model,ts_model,vocoder,diffuser
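
# Example usage (a minimal sketch; assumes the 'Smolie' checkpoints can be fetched):
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     diff_model, ts_model, vocoder, diffuser = load_models('Smolie', device)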

def infer_mel(model, timeshape, code, ref_mel, diffuser, temperature=1.0):
    device = next(model.parameters()).device
    code = code.to(device)
    ref_mel = ref_mel.to(device)
    output_shape = (1, 80, timeshape)  # (batch, n_mel_channels, frames)
    noise = torch.randn(output_shape, device=code.device) * temperature
    mel = diffuser.p_sample_loop(model, output_shape, noise=noise,
                                 model_kwargs={'code_emb': code, 'ref_clips': ref_mel},
                                 progress=True)
    return denormalize_tacotron_mel(mel)

def generate_semantic_tokens(
    text,
    model,
    ref_mels,
    language=None,
    temp=0.7,
    top_p=None,
    top_k=1,
    n_tot_steps=1000,
    device=None
    ):
    semb = []
    with torch.no_grad():
        for n in tqdm(range(n_tot_steps)):
            x = get_inputs(text,semb,ref_mels,device,model.name)
            _,result = model(**x,language=language)
            relevant_logits = result[0,:,-1]
            if top_p is not None:
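                # Nucleus (top-p) sampling: keep the smallest set of highest-probability
                # tokens whose cumulative probability exceeds top_p; mask out the rest.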
                # faster to convert to numpy
                original_device = relevant_logits.device
                relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                sorted_indices = np.argsort(relevant_logits)[::-1]
                sorted_logits = relevant_logits[sorted_indices]
                cumulative_probs = np.cumsum(softmax(sorted_logits))
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                sorted_indices_to_remove[0] = False
                relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                relevant_logits = torch.from_numpy(relevant_logits)
                relevant_logits = relevant_logits.to(original_device)

            if top_k is not None:
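                # Top-k filtering: mask every logit below the k-th largest value.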
                v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                relevant_logits[relevant_logits < v[-1]] = -float("Inf")

            probs = F.softmax(relevant_logits / temp, dim=-1)
            item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
            semb.append(str(code_dec[item_next.item()]))
            if semb[-1] == '<EST>' or semb[-1] == '<PAD>':
                break

            del relevant_logits, probs, item_next

    # Drop the stop token if one was emitted; otherwise keep every generated token.
    if semb and semb[-1] in ('<EST>', '<PAD>'):
        semb = semb[:-1]
    semb = torch.tensor([int(i) for i in semb])
    return semb, result
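
# A minimal sketch of calling this directly (infer_tts does the same internally):
#     tokens, _ = generate_semantic_tokens(text, ts_model, ref_mels,
#                                          temp=0.7, top_p=0.8, top_k=5, device=device)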

def get_inputs(text, semb=[], ref_mels=[], device=torch.device('cpu'), name='Smolie-in'):
    text = text.lower()
    if name == 'Smolie-en':
        text_ids = [text_enc_en['<S>']] + [text_enc_en[i] for i in text.strip()] + [text_enc_en['<E>']]
    else:
        text_ids = [text_enc['<S>']] + [text_enc[i] for i in text.strip()] + [text_enc['<E>']]

    semb_ids = [code_enc['<SST>']] + [code_enc[i] for i in semb]

    return {'text_ids': torch.tensor(text_ids).unsqueeze(0).to(device),
            'codes_ids': torch.tensor(semb_ids).unsqueeze(0).to(device),
            'ref_clips': normalize_tacotron_mel(ref_mels).to(device)}

def get_ref_mels(ref_clips):
    ref_mels = []
    for i in ref_clips:
        # Keep at most the first 500 mel frames of each reference clip
        ref_mels.append(get_mel(i)[0][:, :500])

    # Pad into a fixed (n_clips, 80, 500) tensor; near-zero noise acts as silence padding
    ref_mels_padded = torch.randn((len(ref_mels), 80, 500)) * 1e-8
    for i, mel in enumerate(ref_mels):
        ref_mels_padded[i, :, :mel.size(1)] = mel
    return ref_mels_padded.unsqueeze(0)

def get_mel(filepath):
    audio, sampling_rate = load_wav_to_torch(filepath)
    audio_norm = audio / config.MAX_WAV_VALUE
    y = audio_norm.unsqueeze(0)  # plain tensor; torch.autograd.Variable is deprecated

    assert torch.min(y) >= -1
    assert torch.max(y) <= 1
    magnitudes, phases = stft_fn.transform(y)
    magnitudes = magnitudes.data
    mel_output = torch.matmul(mel_basis, magnitudes)
    mel_output = dynamic_range_compression(mel_output)
    melspec = torch.squeeze(mel_output, 0)
    energy = torch.norm(magnitudes, dim=1).squeeze(0)
    return melspec, list(energy)

def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder,language=None):
    '''
    Generate audio from the given text using a text-to-speech (TTS) pipeline.

    Args:
    text (str): The input text to be synthesized into speech.
    ref_clips (list): A list of paths to reference audio clips, preferably more than 3 clips.
    diffuser (object): A diffusion object used for denoising and guidance in the diffusion model. It should be obtained using load_diffuser.
    diff_model: diffusion model for semantic-to-acoustic tokens.
    ts_model: text-to-semantic model for converting text-to-semantic tokens.
    vocoder: vocoder model for generating waveform from acoustic tokens.
    language (str, optional): Language hint forwarded to the text-to-semantic model. Default is None.

    Returns:
    audio (numpy.ndarray): Generated audio waveform.
    sampling_rate (int): Sampling rate of the generated audio.

    Description:
    The `infer_tts` function takes input text and reference audio clips, and processes them through a TTS pipeline.
    It first performs text preprocessing and generates semantic tokens using the specified text synthesis model.
    Then, it infers mel-spectrogram features using the diffusion model and the provided diffuser.
    Finally, it generates audio from the mel-spectrogram using the vocoder.

    Note: The function requires properly configured diff_model, ts_model, and vocoder objects for successful TTS.

    Example usage:
    audio, sampling_rate = infer_tts("Hello, how are you?", ref_clips, diffuser, diff_model, ts_model, vocoder)
    '''
    device = next(ts_model.parameters()).device
    text = english_cleaners(text)
    ref_mels = get_ref_mels(ref_clips)
    with torch.no_grad():
        sem_tok, _ = generate_semantic_tokens(
                        text,
                        ts_model,
                        ref_mels,
                        language,
                        temp=0.7,
                        top_p=0.8,
                        top_k=5,
                        n_tot_steps=1000,
                        device=device
                    )
        # sem_tok.shape[-1] * 320 / 16000 converts the token count to seconds
        # (320 samples per token at 16 kHz); * 22050 / 256 converts seconds to
        # mel frames (hop size 256 at a 22,050 Hz sampling rate).
        mel_len = int(((sem_tok.shape[-1] * 320 / 16000) * 22050 / 256) + 1)
        mel = infer_mel(diff_model, mel_len, sem_tok.unsqueeze(0) + 1,
                        normalize_tacotron_mel(ref_mels), diffuser, temperature=0.5)

        audio = infer_wav(mel,vocoder)
    
    return audio,config.sampling_rate

def load_diffuser(timesteps = 100, guidance=3):
    '''
    Load and configure a diffuser for denoising and guidance in the diffusion model.

    Args:
    timesteps (int): Number of denoising steps out of 1000. Default is 100.
    guidance (int): Conditioning-free guidance parameter. Default is 3.

    Returns:
    diffuser (object): Configured diffuser object for use in the diffusion model.

    Description:
    The `load_diffuser` function initializes a diffuser with specific settings for denoising and guidance.
    '''
    betas = get_named_beta_schedule('linear',config.sa_timesteps_max)
    diffuser = SpacedDiffusion(use_timesteps=space_timesteps(1000, [timesteps]), model_mean_type='epsilon',
                        model_var_type='learned_range', loss_type='rescaled_mse', betas=betas,
                        conditioning_free=True, conditioning_free_k=guidance)
    diffuser.training=False
    return diffuser
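
# Example (a minimal sketch): fewer denoising steps run faster at some cost in quality.
#     diffuser = load_diffuser(timesteps=50, guidance=3)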

if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    text = 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.'
    ref_clips = glob.glob('/Users/jaskaransingh/Desktop/maha_tts/ref_clips/*.wav')

    # load_models resolves checkpoint paths itself and also returns a configured diffuser
    diff_model, ts_model, vocoder, diffuser = load_models('Smolie', device)
    audio, sr = infer_tts(text, ref_clips, diffuser, diff_model, ts_model, vocoder)
    write('test.wav', sr, audio)